From 1daba748a6a6d65794d8aa5969e6f1620203e1b2 Mon Sep 17 00:00:00 2001 From: Matt Whiteway Date: Wed, 5 Jul 2023 15:49:51 -0400 Subject: [PATCH] linting + docs (#102) * initial contrib docs * isort * flake8 * isort * fixes --- docs/README.md | 2 + docs/contributing.md | 43 ++++++ .../apps/labeled_frame_diagnostics.py | 61 +++++---- lightning_pose/apps/plots.py | 116 +++++++++++------ lightning_pose/apps/utils.py | 85 +++++++----- lightning_pose/apps/video_diagnostics.py | 122 ++++++++++++------ lightning_pose/callbacks.py | 2 +- lightning_pose/data/dali.py | 41 +++--- lightning_pose/data/datamodules.py | 28 ++-- lightning_pose/data/datasets.py | 33 +++-- lightning_pose/data/utils.py | 92 +++++++------ lightning_pose/losses/factory.py | 3 +- lightning_pose/losses/helpers.py | 5 +- lightning_pose/losses/losses.py | 48 +++---- lightning_pose/metrics.py | 6 +- .../models/backbones/torchvision.py | 13 +- lightning_pose/models/backbones/vits.py | 3 +- lightning_pose/models/base.py | 9 +- lightning_pose/models/heatmap_tracker.py | 29 ++--- .../models/heatmap_tracker_mhcrnn.py | 15 ++- lightning_pose/models/regression_tracker.py | 9 +- lightning_pose/utils/__init__.py | 2 +- lightning_pose/utils/fiftyone.py | 17 +-- lightning_pose/utils/io.py | 13 +- lightning_pose/utils/pca.py | 27 ++-- lightning_pose/utils/predictions.py | 41 +++--- lightning_pose/utils/scripts.py | 61 ++++----- lightning_pose/utils/tests.py | 1 + scripts/converters/dlc2lp.py | 4 +- scripts/create_fiftyone_dataset.py | 8 +- scripts/predict_new_vids.py | 17 ++- scripts/train_hydra.py | 13 +- setup.cfg | 16 +++ setup.py | 7 +- tests/conftest.py | 17 +-- tests/data/test_dali.py | 5 +- tests/data/test_datamodules.py | 3 +- tests/data/test_datasets.py | 1 - tests/data/test_utils.py | 13 +- tests/losses/test_helpers.py | 1 - tests/losses/test_losses.py | 6 +- tests/models/test_base.py | 2 - tests/models/test_heatmap_tracker.py | 1 - tests/models/test_heatmap_tracker_mhcrnn.py | 1 - tests/models/test_regression_tracker.py | 1 - tests/utils/test_pca.py | 9 +- 46 files changed, 625 insertions(+), 427 deletions(-) create mode 100644 docs/contributing.md create mode 100644 setup.cfg diff --git a/docs/README.md b/docs/README.md index a28af79c..e0477be5 100644 --- a/docs/README.md +++ b/docs/README.md @@ -31,3 +31,5 @@ where to store the model code, how to make it visible to users, how to provide hyperparameters for model construction, and how to connect it to data loaders and losses + +* [General contributing guidelines](contributing.md): pull requests, testing, linting, etc. diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 00000000..5ea71bb7 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,43 @@ +# Contributing + +We welcome community contributions to the Lightning Pose repo! +If you have found a bug or would like to request a minor change, please +[open an issue](https://github.com/danbider/lightning-pose/issues). + +In order to contribute code to the repo, please follow the steps below. + +You will also need to install the following dev packages: +```bash +pip install flake8 isort +``` + +### Create a pull request +Please fork the Lightning Pose repo, make your changes, and then +[open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) +from your fork. Please read through the rest of this document before submitting the request. 
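+
+A minimal sketch of this workflow from the command line (here `<username>` and
+`<branch-name>` are placeholders for your GitHub username and feature branch):
+```bash
+# clone your fork and create a feature branch
+git clone https://github.com/<username>/lightning-pose.git
+cd lightning-pose
+git checkout -b <branch-name>
+# ...edit, commit, then push the branch back to your fork...
+git push --set-upstream origin <branch-name>
+```
+GitHub will then offer to open a pull request against the upstream repo.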
+ +### Linting +Linters automatically find (and sometimes fix) formatting issues in the code. We use two, which +are run from the command line in the Lightning Pose repo: + +* `flake8`: warns of syntax errors, possible bugs, stylistic errors, etc. Please fix these! +```bash +flake8 . +``` + +* `isort`: automatically sorts import statements +```bash +isort . +``` + +### Testing +We currently do not have a continuous integration (CI) setup for the Lightning Pose repo due to its +reliance on GPUs (and the relative expense of CI services that provide GPU machines for testing). +Therefore, it is imperative that you run the unit tests yourself and verify that all tests have +passed before submitting your request (and upon each new push to that request). + +To run the tests locally, you must have access to a GPU. Navigate to the Lightning Pose directory +and simply run +```bash +pytest +``` diff --git a/lightning_pose/apps/labeled_frame_diagnostics.py b/lightning_pose/apps/labeled_frame_diagnostics.py index 8f2b9891..1905f904 100644 --- a/lightning_pose/apps/labeled_frame_diagnostics.py +++ b/lightning_pose/apps/labeled_frame_diagnostics.py @@ -2,26 +2,30 @@ Refer to apps.md for information on how to use this file. -streamlit run labeled_frame_diagnostics.py -- --model_dir "/home/zeus/content/Pose-app/data/demo/models" +streamlit run labeled_frame_diagnostics.py -- --model_dir ".../Pose-app/data/demo/models" """ import argparse import copy -import numpy as np -import streamlit as st -import seaborn as sns -import pandas as pd +import os +from collections import defaultdict from pathlib import Path + import numpy as np -from collections import defaultdict -import os +import pandas as pd +import seaborn as sns +import streamlit as st -from lightning_pose.apps.utils import build_precomputed_metrics_df, get_df_box, get_df_scatter -from lightning_pose.apps.utils import update_labeled_file_list -from lightning_pose.apps.utils import get_model_folders, get_model_folders_vis -from lightning_pose.apps.plots import make_seaborn_catplot, make_plotly_scatterplot, get_y_label -from lightning_pose.apps.plots import make_plotly_catplot +from lightning_pose.apps.plots import get_y_label, make_plotly_catplot, make_plotly_scatterplot +from lightning_pose.apps.utils import ( + build_precomputed_metrics_df, + get_df_box, + get_df_scatter, + get_model_folders, + get_model_folders_vis, + update_labeled_file_list, +) # catplot_options = ["boxen", "box", "bar", "violin", "strip"] # for seaborn catplot_options = ["box", "violin", "strip"] # for plotly @@ -31,7 +35,6 @@ def run(): - args = parser.parse_args() st.title("Labeled Frame Diagnostics") @@ -48,7 +51,7 @@ def run(): # add a multiselect that shows existing model folders, and allows the user to de-select models # assume we have args.model_dir and we search two levels down for model folders model_folders = get_model_folders(args.model_dir) - + # get the last two levels of each path to be presented to user model_folders_vis = get_model_folders_vis(model_folders) @@ -56,7 +59,7 @@ def run(): # append this to full path selected_models = ["/" + os.path.join(args.model_dir, f) for f in selected_models_vis] - + # search for prediction files in the selected model folders prediction_files = update_labeled_file_list(selected_models, use_ood=args.use_ood) @@ -65,7 +68,6 @@ def run(): model_names = copy.copy(selected_models_vis) if len(prediction_files) > 0: # otherwise don't try to proceed - # --------------------------------------------------- # load data # 
--------------------------------------------------- @@ -107,9 +109,11 @@ def run(): # concat dataframes, collapsing hierarchy and making df fatter. keypoint_names = list( - [c[0] for c in dframes_metrics[new_names[0]]["confidence"].columns[1::3]]) + [c[0] for c in dframes_metrics[new_names[0]]["confidence"].columns[1::3]] + ) df_metrics = build_precomputed_metrics_df( - dframes=dframes_metrics, keypoint_names=keypoint_names) + dframes=dframes_metrics, keypoint_names=keypoint_names + ) metric_options = list(df_metrics.keys()) # ---------------------------------------------------' @@ -122,7 +126,8 @@ def run(): with col0: # choose from individual keypoints, their mean, or all at once keypoint_to_plot = st.selectbox( - "Keypoint:", ["mean", "ALL", *keypoint_names], key="keypoint") + "Keypoint:", ["mean", "ALL", *keypoint_names], key="keypoint" + ) with col1: # choose which metric to plot @@ -142,7 +147,6 @@ def run(): # --------------------------------------------------- with sup_col00: - st.header("Compare multiple models") # enumerate plotting options @@ -158,7 +162,9 @@ def run(): plot_scale = st.radio("Y-axis scale", scale_options, horizontal=True) # filter data - df_metrics_filt = df_metrics[metric_to_plot][df_metrics[metric_to_plot].set == data_type] + df_metrics_filt = df_metrics[metric_to_plot][ + df_metrics[metric_to_plot].set == data_type + ] n_frames_per_dtype = df_metrics_filt.shape[0] // len(selected_models) # plot data @@ -180,15 +186,19 @@ def run(): else: top = 0.9 fig_box.fig.subplots_adjust(top=top) - fig_box.fig.suptitle("All keypoints (%i %s frames)" % (n_frames_per_dtype, data_type)) + fig_box.fig.suptitle( + f"All keypoints ({n_frames_per_dtype} {data_type} frames)" + ) st.pyplot(fig_box) else: - st.markdown("###") fig_box = make_plotly_catplot( - x="model_name", y=keypoint_to_plot, - data=df_metrics_filt[df_metrics_filt[keypoint_to_plot] > int(plot_epsilon)], + x="model_name", + y=keypoint_to_plot, + data=df_metrics_filt[ + df_metrics_filt[keypoint_to_plot] > int(plot_epsilon) + ], x_label="Model name", y_label=y_label, title=title, @@ -208,7 +218,6 @@ def run(): # scatterplots # --------------------------------------------------- with sup_col01: - st.header("Compare two models") col6, col7, col8 = st.columns(3) diff --git a/lightning_pose/apps/plots.py b/lightning_pose/apps/plots.py index 13df1198..b496d4dd 100644 --- a/lightning_pose/apps/plots.py +++ b/lightning_pose/apps/plots.py @@ -1,11 +1,11 @@ """A collection of visualizations for various pose estimation performance metrics.""" -from matplotlib import pyplot as plt import numpy as np import plotly.express as px import plotly.graph_objects as go -from plotly.subplots import make_subplots import seaborn as sns +from matplotlib import pyplot as plt +from plotly.subplots import make_subplots pix_error_key = "pixel error" conf_error_key = "confidence" @@ -15,10 +15,14 @@ def get_y_label(to_compute: str) -> str: - if to_compute == 'rmse' or to_compute == "pixel_error" or to_compute == "pixel error": - return 'Pixel Error' - if to_compute == 'temporal_norm' or to_compute == 'temporal norm': - return 'Temporal norm (pix.)' + if ( + to_compute == "rmse" + or to_compute == "pixel_error" + or to_compute == "pixel error" + ): + return "Pixel Error" + if to_compute == "temporal_norm" or to_compute == "temporal norm": + return "Temporal norm (pix.)" elif to_compute == "pca_multiview" or to_compute == "pca multiview": return "Multiview PCA\nrecon error (pix.)" elif to_compute == "pca_singleview" or to_compute == "pca 
singleview": @@ -28,7 +32,8 @@ def get_y_label(to_compute: str) -> str: def make_seaborn_catplot( - x, y, data, x_label, y_label, title, log_y=False, plot_type="box", figsize=(5, 5)): + x, y, data, x_label, y_label, title, log_y=False, plot_type="box", figsize=(5, 5) +): sns.set_context("paper") fig = plt.figure(figsize=figsize) if plot_type == "box": @@ -53,9 +58,17 @@ def make_seaborn_catplot( def make_plotly_catplot( - x, y, data, x_label, y_label, title, log_y=False, plot_type="box", - fig_height=500, fig_width=500, - ): + x, + y, + data, + x_label, + y_label, + title, + log_y=False, + plot_type="box", + fig_height=500, + fig_width=500, +): if plot_type == "box" or plot_type == "boxen": fig = px.box(data, x=x, y=y, log_y=log_y) elif plot_type == "violin": @@ -66,23 +79,37 @@ def make_plotly_catplot( # fig = px.bar(data, x=x, y=y, log_y=log_y) elif plot_type == "hist": fig = px.histogram( - data, x=x, color="model_name", marginal="rug", barmode="overlay", + data, + x=x, + color="model_name", + marginal="rug", + barmode="overlay", ) fig.update_layout( - yaxis_title=y_label, xaxis_title=x_label, title=title, - height=fig_height, width=fig_width, + yaxis_title=y_label, + xaxis_title=x_label, + title=title, + height=fig_height, + width=fig_width, ) return fig def make_plotly_scatterplot( - model_0, model_1, df, metric_name, title, - axes_scale="linear", - facet_col=None, n_cols=0, opacity=0.5, hover_data=None, - fig_height=500, fig_width=500, + model_0, + model_1, + df, + metric_name, + title, + axes_scale="linear", + facet_col=None, + n_cols=0, + opacity=0.5, + hover_data=None, + fig_height=500, + fig_width=500, ): - xlabel = "%s
(%s)" % (metric_name, model_0) ylabel = "%s
(%s)" % (metric_name, model_1) @@ -90,9 +117,12 @@ def make_plotly_scatterplot( fig_scatter = px.scatter( df, - x=model_0, y=model_1, - facet_col=facet_col, facet_col_wrap=n_cols, - log_x=log_scatter, log_y=log_scatter, + x=model_0, + y=model_1, + facet_col=facet_col, + facet_col_wrap=n_cols, + log_x=log_scatter, + log_y=log_scatter, opacity=opacity, hover_data=hover_data, # trendline="ols", @@ -106,13 +136,12 @@ def make_plotly_scatterplot( trace.update(legendgroup="trendline", showlegend=False) fig_scatter.add_trace(trace, row="all", col="all", exclude_empty_subplots=True) fig_scatter.update_layout(title=title, width=fig_width, height=fig_height) - fig_scatter.update_traces(marker={'size': 5}) + fig_scatter.update_traces(marker={"size": 5}) return fig_scatter def plot_precomputed_traces(df_metrics, df_traces, cols): - # ------------------------------------------------------------- # setup # ------------------------------------------------------------- @@ -133,7 +162,8 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): row_heights.insert(0, 0.75) fig_traces = make_subplots( - rows=rows, cols=1, + rows=rows, + cols=1, shared_xaxes=True, x_title="Frame number", row_heights=row_heights, @@ -160,19 +190,22 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): go.Scatter( name=col, x=np.arange(df_traces.shape[0]), - y=df_metrics[error_key][kp][df_metrics[error_key].model_name == model], - mode='lines', + y=df_metrics[error_key][kp][ + df_metrics[error_key].model_name == model + ], + mode="lines", line=dict(color=colors[c]), showlegend=False, ), - row=row, col=1 + row=row, + col=1, ) if error_key == temp_norm_error_key: - yaxis_labels['yaxis%i' % row] = "temporal
norm" + yaxis_labels["yaxis%i" % row] = "temporal
norm" elif error_key == pcamv_error_key: - yaxis_labels['yaxis%i' % row] = "pca multi
error" + yaxis_labels["yaxis%i" % row] = "pca multi
error" elif error_key == pcasv_error_key: - yaxis_labels['yaxis%i' % row] = "pca single
error" + yaxis_labels["yaxis%i" % row] = "pca single
error" row += 1 # ------------------------------------------------------------- @@ -181,7 +214,9 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): for coord in ["x", "y"]: for c, col in enumerate(cols): pieces = col.split("_%s_" % coordinate) - assert len(pieces) == 2 # otherwise "_[x/y]_" appears in keypoint or model name :( + assert ( + len(pieces) == 2 + ) # otherwise "_[x/y]_" appears in keypoint or model name :( kp = pieces[0] model = pieces[1] new_col = col.replace("_%s_" % coordinate, "_%s_" % coord) @@ -190,13 +225,14 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): name=model, x=np.arange(df_traces.shape[0]), y=df_traces[new_col], - mode='lines', + mode="lines", line=dict(color=colors[c]), showlegend=False if coord == "x" else True, ), - row=row, col=1 + row=row, + col=1, ) - yaxis_labels['yaxis%i' % row] = "%s coordinate" % coord + yaxis_labels["yaxis%i" % row] = "%s coordinate" % coord row += 1 # ------------------------------------------------------------- @@ -209,13 +245,14 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): name=col_l, x=np.arange(df_traces.shape[0]), y=df_traces[col_l], - mode='lines', + mode="lines", line=dict(color=colors[c]), showlegend=False, ), - row=row, col=1 + row=row, + col=1, ) - yaxis_labels['yaxis%i' % row] = "confidence" + yaxis_labels["yaxis%i" % row] = "confidence" row += 1 # ------------------------------------------------------------- @@ -224,8 +261,9 @@ def plot_precomputed_traces(df_metrics, df_traces, cols): for k, v in yaxis_labels.items(): fig_traces["layout"][k]["title"] = v fig_traces.update_layout( - width=800, height=np.sum(row_heights) * 125, - title_text="Timeseries of %s" % keypoint + width=800, + height=np.sum(row_heights) * 125, + title_text="Timeseries of %s" % keypoint, ) return fig_traces diff --git a/lightning_pose/apps/utils.py b/lightning_pose/apps/utils.py index eccb06b2..d5f80dce 100644 --- a/lightning_pose/apps/utils.py +++ b/lightning_pose/apps/utils.py @@ -1,12 +1,12 @@ """Utility functions for streamlit apps.""" -import numpy as np -import pandas as pd -from pathlib import Path -import streamlit as st import os -from typing import List, Dict, Tuple, Optional from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import pandas as pd +import streamlit as st pix_error_key = "pixel error" conf_error_key = "confidence" @@ -21,14 +21,17 @@ def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False): for model_pred_folder in model_preds_folders: # pull labeled results from each model folder # wrap in Path so that it looks like an UploadedFile object - model_preds = [f for f in os.listdir(model_pred_folder) - if os.path.isfile(os.path.join(model_pred_folder, f))] + model_preds = [ + f + for f in os.listdir(model_pred_folder) + if os.path.isfile(os.path.join(model_pred_folder, f)) + ] ret_files = [] for file in model_preds: - if 'predictions' in file: - if 'new' not in file and not use_ood: + if "predictions" in file: + if "new" not in file and not use_ood: ret_files.append(Path(file)) - elif 'new' in file and use_ood: + elif "new" in file and use_ood: ret_files.append(Path(file)) else: pass @@ -42,8 +45,11 @@ def update_vid_metric_files_list(video: str, model_preds_folders: list): for model_preds_folder in model_preds_folders: # pull each prediction file associated with a particular video # wrap in Path so that it looks like an UploadedFile object - model_preds = [f for f in os.listdir(os.path.join(model_preds_folder, 
'video_preds')) - if os.path.isfile(os.path.join(model_preds_folder, 'video_preds', f))] + model_preds = [ + f + for f in os.listdir(os.path.join(model_preds_folder, "video_preds")) + if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f)) + ] ret_files = [] for file in model_preds: if video in file: @@ -59,14 +65,17 @@ def get_all_videos(model_preds_folders: list): # returned by streamlit's file_uploader ret_videos = set() for model_preds_folder in model_preds_folders: - model_preds = [f for f in os.listdir(os.path.join(model_preds_folder, 'video_preds')) - if os.path.isfile(os.path.join(model_preds_folder, 'video_preds', f))] + model_preds = [ + f + for f in os.listdir(os.path.join(model_preds_folder, "video_preds")) + if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f)) + ] for file in model_preds: - if 'temporal' in file: - vid_file = file.split('_temporal_norm.csv')[0] + if "temporal" in file: + vid_file = file.split("_temporal_norm.csv")[0] ret_videos.add(vid_file) - elif 'temporal' not in file and 'pca' not in file: - vid_file = file.split('.csv')[0] + elif "temporal" not in file and "pca" not in file: + vid_file = file.split(".csv")[0] ret_videos.add(vid_file) return list(ret_videos) @@ -106,12 +115,16 @@ def get_df_box(df_orig, keypoint_names, model_names): def get_df_scatter(df_0, df_1, data_type, model_names, keypoint_names): df_scatters = [] for keypoint in keypoint_names: - df_scatters.append(pd.DataFrame({ - "img_file": df_0.img_file[df_0.set == data_type], - "keypoint": keypoint, - model_names[0]: df_0[keypoint][df_0.set == data_type], - model_names[1]: df_1[keypoint][df_1.set == data_type], - })) + df_scatters.append( + pd.DataFrame( + { + "img_file": df_0.img_file[df_0.set == data_type], + "keypoint": keypoint, + model_names[0]: df_0[keypoint][df_0.set == data_type], + model_names[1]: df_1[keypoint][df_1.set == data_type], + } + ) + ) return pd.concat(df_scatters) @@ -139,19 +152,23 @@ def build_precomputed_metrics_df( concat_dfs = defaultdict(list) for model_name, df_dict in dframes.items(): for metric_name, df in df_dict.items(): - if 'confidence' in metric_name: + if "confidence" in metric_name: df_ = compute_confidence( - df=df, keypoint_names=keypoint_names, model_name=model_name, **kwargs) + df=df, + keypoint_names=keypoint_names, + model_name=model_name, + **kwargs + ) concat_dfs[conf_error_key].append(df_) df_ = get_precomputed_error(df, keypoint_names, model_name, **kwargs) - if 'single' in metric_name: + if "single" in metric_name: concat_dfs[pcasv_error_key].append(df_) - elif 'multi' in metric_name: + elif "multi" in metric_name: concat_dfs[pcamv_error_key].append(df_) - elif 'temporal' in metric_name: + elif "temporal" in metric_name: concat_dfs[temp_norm_error_key].append(df_) - elif 'pixel' in metric_name: + elif "pixel" in metric_name: concat_dfs[pix_error_key].append(df_) for key in concat_dfs.keys(): @@ -168,15 +185,15 @@ def get_precomputed_error( df_ = df df_["model_name"] = model_name df_["mean"] = df_[keypoint_names[:-1]].mean(axis=1) - df_.rename(columns={df.columns[0]:'img_file'}, inplace=True) + df_.rename(columns={df.columns[0]: "img_file"}, inplace=True) return df_ @st.cache_data def compute_confidence( - df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs) -> pd.DataFrame: - + df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs +) -> pd.DataFrame: if df.shape[1] % 3 == 1: # get rid of "set" column if present tmp = df.iloc[:, :-1].to_numpy().reshape(df.shape[0], -1, 3) @@ -218,6 
+235,6 @@ def get_model_folders(model_dir): def get_model_folders_vis(model_folders): fs = [] for f in model_folders: - fs.append(f.split('/')[-2:]) + fs.append(f.split("/")[-2:]) model_folders_vis = [os.path.join(*f) for f in fs] return model_folders_vis diff --git a/lightning_pose/apps/video_diagnostics.py b/lightning_pose/apps/video_diagnostics.py index 36726986..4bcee624 100644 --- a/lightning_pose/apps/video_diagnostics.py +++ b/lightning_pose/apps/video_diagnostics.py @@ -8,25 +8,30 @@ """ import argparse -import os import copy -import pandas as pd +import os +from collections import defaultdict from pathlib import Path + +import pandas as pd import streamlit as st -from collections import defaultdict -from lightning_pose.apps.utils import build_precomputed_metrics_df, get_col_names, concat_dfs -from lightning_pose.apps.utils import update_vid_metric_files_list, get_all_videos -from lightning_pose.apps.utils import get_model_folders, get_model_folders_vis -from lightning_pose.apps.plots import get_y_label -from lightning_pose.apps.plots import make_seaborn_catplot, make_plotly_catplot, plot_precomputed_traces +from lightning_pose.apps.plots import get_y_label, make_plotly_catplot, plot_precomputed_traces +from lightning_pose.apps.utils import ( + build_precomputed_metrics_df, + concat_dfs, + get_all_videos, + get_col_names, + get_model_folders, + get_model_folders_vis, + update_vid_metric_files_list, +) catplot_options = ["boxen", "box", "violin", "strip", "hist"] scale_options = ["linear", "log"] def run(): - args = parser.parse_args() st.title("Video Diagnostics") @@ -35,22 +40,28 @@ def run(): if args.make_dir: os.makedirs(args.model_dir, exist_ok=True) if not os.path.isdir(args.model_dir): - st.text(f"--model_dir {args.model_dir} does not exist." - f"\nPlease check the path and try again.") + st.text( + f"--model_dir {args.model_dir} does not exist." + f"\nPlease check the path and try again." 
+ ) st.sidebar.header("Data Settings") # ----- selecting which models to use ----- model_folders = get_model_folders(args.model_dir) - + # get the last two levels of each path to be presented to user model_folders_vis = get_model_folders_vis(model_folders) - selected_models_vis = st.sidebar.multiselect("Select models", model_folders_vis, default=None) + selected_models_vis = st.sidebar.multiselect( + "Select models", model_folders_vis, default=None + ) # append this to full path - selected_models = ["/" + os.path.join(args.model_dir, f) for f in selected_models_vis] - + selected_models = [ + "/" + os.path.join(args.model_dir, f) for f in selected_models_vis + ] + # ----- selecting videos to analyze ----- all_videos_: list = get_all_videos(selected_models) @@ -58,12 +69,12 @@ def run(): video_to_plot = st.sidebar.selectbox("Select a video:", [*all_videos_], key="video") prediction_files = update_vid_metric_files_list( - video=video_to_plot, model_preds_folders=selected_models) - + video=video_to_plot, model_preds_folders=selected_models + ) + model_names = copy.copy(selected_models_vis) if len(prediction_files) > 0: # otherwise don't try to proceed - # --------------------------------------------------- # load data # --------------------------------------------------- @@ -76,26 +87,34 @@ def run(): model_folder = selected_models[p] for model_pred_file in model_pred_files: - model_pred_file_path = os.path.join(model_folder, 'video_preds', model_pred_file) + model_pred_file_path = os.path.join( + model_folder, "video_preds", model_pred_file + ) if not isinstance(model_pred_file, Path): model_pred_file.seek(0) # reset buffer after reading - if 'pca' in str(model_pred_file) \ - or 'temporal' in str(model_pred_file) \ - or 'pixel' in str(model_pred_file): + if ( + "pca" in str(model_pred_file) + or "temporal" in str(model_pred_file) + or "pixel" in str(model_pred_file) + ): dframe = pd.read_csv(model_pred_file_path, index_col=None) dframes_metrics[model_name][str(model_pred_file)] = dframe else: - dframe = pd.read_csv(model_pred_file_path, header=[1, 2], index_col=0) + dframe = pd.read_csv( + model_pred_file_path, header=[1, 2], index_col=0 + ) dframes_traces[model_name] = dframe - dframes_metrics[model_name]['confidence'] = dframe - data_types = dframe.iloc[:, -1].unique() + dframes_metrics[model_name]["confidence"] = dframe + # data_types = dframe.iloc[:, -1].unique() # edit model names if desired, to simplify plotting st.sidebar.write("Model display names (editable)") new_names = [] og_names = list(dframes_metrics.keys()) for name in og_names: - new_name = st.sidebar.text_input(label="name", value=name, label_visibility="hidden") + new_name = st.sidebar.text_input( + label="name", value=name, label_visibility="hidden" + ) new_names.append(new_name) # change dframes key names to new ones @@ -110,7 +129,8 @@ def run(): # concat dataframes, collapsing hierarchy and making df fatter. 
df_concat, keypoint_names = concat_dfs(dframes_traces) df_metrics = build_precomputed_metrics_df( - dframes=dframes_metrics, keypoint_names=keypoint_names) + dframes=dframes_metrics, keypoint_names=keypoint_names + ) metric_options = list(df_metrics.keys()) # --------------------------------------------------- @@ -128,34 +148,50 @@ def run(): plot_type = st.selectbox("Plot style:", catplot_options, key="plot_type") with col02: - plot_scale = st.radio("Y-axis scale", scale_options, key="plot_scale", horizontal=True) + plot_scale = st.radio( + "Y-axis scale", scale_options, key="plot_scale", horizontal=True + ) x_label = "Model Name" y_label = get_y_label(metric_to_plot) log_y = False if plot_scale == "linear" else True - + # DB: commented out seaborn for visual coherence # fig_cat = make_seaborn_catplot( - # x="model_name", y="mean", data=df_metrics[metric_to_plot], log_y=log_y, x_label=x_label, - # y_label=y_label, title="Average over all keypoints", plot_type=plot_type) + # x="model_name", y="mean", data=df_metrics[metric_to_plot], log_y=log_y, + # x_label=x_label, y_label=y_label, title="Average over all keypoints", + # plot_type=plot_type + # ) # st.pyplot(fig_cat) # select keypoint to plot keypoint_to_plot = st.selectbox( - "Select a keypoint:", pd.Series([*keypoint_names, "mean"]), key="keypoint_to_plot", + "Select a keypoint:", + pd.Series([*keypoint_names, "mean"]), + key="keypoint_to_plot", ) - + if plot_type != "hist": # show plot per keypoint plotly_flex_fig = make_plotly_catplot( - x="model_name", y=keypoint_to_plot, data=df_metrics[metric_to_plot], - x_label=x_label, y_label=y_label, title=keypoint_to_plot, plot_type=plot_type, - log_y=log_y + x="model_name", + y=keypoint_to_plot, + data=df_metrics[metric_to_plot], + x_label=x_label, + y_label=y_label, + title=keypoint_to_plot, + plot_type=plot_type, + log_y=log_y, ) else: plotly_flex_fig = make_plotly_catplot( - x=keypoint_to_plot, y=None, data=df_metrics[metric_to_plot], - x_label=y_label, y_label="Frame count", title=keypoint_to_plot, plot_type="hist", + x=keypoint_to_plot, + y=None, + data=df_metrics[metric_to_plot], + x_label=y_label, + y_label="Frame count", + title=keypoint_to_plot, + plot_type="hist", ) st.plotly_chart(plotly_flex_fig) @@ -168,8 +204,9 @@ def run(): with col10: models = st.multiselect( - "Models:", - pd.Series(list(dframes_metrics.keys())), default=list(dframes_metrics.keys()) + "Models:", + pd.Series(list(dframes_metrics.keys())), + default=list(dframes_metrics.keys()), ) with col11: @@ -181,10 +218,9 @@ def run(): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=[]) - parser.add_argument('--make_dir', action='store_true', default=False) + parser.add_argument("--model_dir", type=str, default=[]) + parser.add_argument("--make_dir", action="store_true", default=False) run() diff --git a/lightning_pose/callbacks.py b/lightning_pose/callbacks.py index a3efeb2e..31404537 100644 --- a/lightning_pose/callbacks.py +++ b/lightning_pose/callbacks.py @@ -1,6 +1,6 @@ -from lightning.pytorch.callbacks import Callback import lightning.pytorch as pl import torch +from lightning.pytorch.callbacks import Callback class AnnealWeight(Callback): diff --git a/lightning_pose/data/dali.py b/lightning_pose/data/dali.py index 9303e21b..5cd90031 100644 --- a/lightning_pose/data/dali.py +++ b/lightning_pose/data/dali.py @@ -1,18 +1,17 @@ """Data pipelines based on efficient video reading by nvidia dali package.""" -from nvidia.dali import pipeline_def +from 
typing import Dict, List, Literal, Optional, Union + +import numpy as np import nvidia.dali.fn as fn -from nvidia.dali.plugin.pytorch import LastBatchPolicy -from nvidia.dali.plugin.pytorch import DALIGenericIterator import nvidia.dali.types as types -from omegaconf import DictConfig import torch -import numpy as np -from typing import List, Dict, Optional, Union, Literal, Tuple - +from nvidia.dali import pipeline_def +from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy +from omegaconf import DictConfig from lightning_pose.data import _IMAGENET_MEAN, _IMAGENET_STD -from lightning_pose.data.utils import count_frames, UnlabeledBatchDict +from lightning_pose.data.utils import UnlabeledBatchDict, count_frames _DALI_DEVICE = "gpu" if torch.cuda.is_available() else "cpu" @@ -151,7 +150,6 @@ def __len__(self) -> int: @staticmethod def _dali_output_to_tensors(batch: list) -> UnlabeledBatchDict: - # always batch_size=1 # shape (sequence_length, 3, H, W) frames = batch[0]["frames"][0, :, :, :, :] @@ -201,15 +199,20 @@ def num_iters(self) -> int: if self.model_type == "base": return int(np.ceil(self.frame_count / (pipe_dict["sequence_length"]))) elif self.model_type == "context": - if pipe_dict["step"] == 1: # 0-5, 1-6, 2-7, 3-8, 4-9 ... + if pipe_dict["step"] == 1: # 0-5, 1-6, 2-7, 3-8, 4-9 ... return int(np.ceil(self.frame_count / (pipe_dict["batch_size"]))) elif pipe_dict["step"] == pipe_dict["sequence_length"]: # taking the floor because during training we don't care about missing the last # non-full batch. we prefer having fewer batches but valid. - return int(np.floor( - self.frame_count / (pipe_dict["batch_size"] * pipe_dict["sequence_length"]))) - elif (pipe_dict["batch_size"] == 1) \ - and (pipe_dict["step"] == (pipe_dict["sequence_length"] - 4)): + return int( + np.floor( + self.frame_count + / (pipe_dict["batch_size"] * pipe_dict["sequence_length"]) + ) + ) + elif (pipe_dict["batch_size"] == 1) and ( + pipe_dict["step"] == (pipe_dict["sequence_length"] - 4) + ): # the case of prediction with a single sequence at a time and internal model # reshapes if pipe_dict["step"] <= 0: @@ -218,14 +221,18 @@ def num_iters(self) -> int: "cfg.dali.context.predict.sequence_length to be > 4" ) # remove the first sequence - data_except_first_batch = self.frame_count - pipe_dict["sequence_length"] + data_except_first_batch = ( + self.frame_count - pipe_dict["sequence_length"] + ) # calculate how many "step"s are needed to get at least to the end # count back the first sequence - num_iters = int(np.ceil(data_except_first_batch / pipe_dict["step"])) + 1 + num_iters = ( + int(np.ceil(data_except_first_batch / pipe_dict["step"])) + 1 + ) return num_iters else: raise NotImplementedError - + def _setup_pipe_dict(self, filenames: List[str], imgaug: str) -> Dict[str, dict]: """all of the pipe() args in one place""" dict_args = {} diff --git a/lightning_pose/data/datamodules.py b/lightning_pose/data/datamodules.py index d817df2b..2bb13943 100644 --- a/lightning_pose/data/datamodules.py +++ b/lightning_pose/data/datamodules.py @@ -1,22 +1,20 @@ """Data modules split a dataset into train, val, and test modules.""" -from nvidia.dali.plugin.pytorch import LastBatchPolicy -import os -from omegaconf import DictConfig +from typing import List, Literal, Optional, Union + import lightning.pytorch as pl import torch +from lightning.pytorch.utilities import CombinedLoader +from omegaconf import DictConfig from torch.utils.data import DataLoader, random_split -from typing import Dict, List, Literal, 
Optional, Tuple, Union, TypedDict -from lightning_pose.data.dali import PrepareDALI, LitDaliWrapper +from lightning_pose.data.dali import PrepareDALI from lightning_pose.data.utils import ( - split_sizes_from_probabilities, - compute_num_train_frames, SemiSupervisedDataLoaderDict, + compute_num_train_frames, + split_sizes_from_probabilities, ) from lightning_pose.utils.io import check_video_paths -from lightning.pytorch.utilities import CombinedLoader - _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -72,9 +70,12 @@ def __init__( self.torch_seed = torch_seed def setup(self, stage: Optional[str] = None): # stage arg needed for ptl - datalen = self.dataset.__len__() - print("Number of labeled images in the full dataset (train+val+test): {}".format(datalen)) + print( + "Number of labeled images in the full dataset (train+val+test): {}".format( + datalen + ) + ) # split data based on provided probabilities data_splits_list = split_sizes_from_probabilities( @@ -92,8 +93,9 @@ def setup(self, stage: Optional[str] = None): # stage arg needed for ptl # further subsample training data if desired if self.train_frames is not None: - - n_frames = compute_num_train_frames(len(self.train_dataset), self.train_frames) + n_frames = compute_num_train_frames( + len(self.train_dataset), self.train_frames + ) if n_frames < len(self.train_dataset): # split the data a second time to reflect further subsampling from diff --git a/lightning_pose/data/datasets.py b/lightning_pose/data/datasets.py index f6502dc8..e37a1afc 100644 --- a/lightning_pose/data/datasets.py +++ b/lightning_pose/data/datasets.py @@ -1,18 +1,22 @@ """Dataset objects store images, labels, and functions for manipulation.""" +import os +from typing import Callable, List, Literal, Optional + import imgaug.augmenters as iaa import numpy as np -import os import pandas as pd -from PIL import Image import torch -from torchvision import transforms -from typing import Callable, List, Literal, Optional +from PIL import Image from torchtyping import TensorType +from torchvision import transforms from lightning_pose.data import _IMAGENET_MEAN, _IMAGENET_STD -from lightning_pose.data.utils import generate_heatmaps -from lightning_pose.data.utils import BaseLabeledExampleDict, HeatmapLabeledExampleDict +from lightning_pose.data.utils import ( + BaseLabeledExampleDict, + HeatmapLabeledExampleDict, + generate_heatmaps, +) from lightning_pose.utils.io import get_keypoint_names _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -83,7 +87,9 @@ def __init__( csv_file = options[0] csv_data = pd.read_csv(csv_file, header=header_rows, index_col=0) - self.keypoint_names = get_keypoint_names(csv_file=csv_file, header_rows=header_rows) + self.keypoint_names = get_keypoint_names( + csv_file=csv_file, header_rows=header_rows + ) self.image_names = list(csv_data.index) self.keypoints = torch.tensor(csv_data.to_numpy(), dtype=torch.float32) # convert to x,y coordinates @@ -114,7 +120,6 @@ def __len__(self) -> int: return len(self.image_names) def __getitem__(self, idx: int) -> BaseLabeledExampleDict: - img_name = self.image_names[idx] keypoints_on_image = self.keypoints[idx] @@ -190,7 +195,9 @@ def __getitem__(self, idx: int) -> BaseLabeledExampleDict: image_frames_tensor = torch.unsqueeze(transformed_image, dim=0) else: image_expand = torch.unsqueeze(transformed_image, dim=0) - image_frames_tensor = torch.cat((image_frames_tensor, image_expand), dim=0) + image_frames_tensor = torch.cat( + (image_frames_tensor, image_expand), dim=0 + ) 
transformed_images = image_frames_tensor @@ -261,8 +268,8 @@ def __init__( @property def output_shape(self) -> tuple: return ( - self.height // 2 ** self.downsample_factor, - self.width // 2 ** self.downsample_factor, + self.height // 2**self.downsample_factor, + self.width // 2**self.downsample_factor, ) def compute_heatmap( @@ -323,7 +330,9 @@ def __getitem__(self, idx: int) -> HeatmapLabeledExampleDict: """ example_dict: BaseLabeledExampleDict = super().__getitem__(idx) - if len(self.imgaug_transform) == 1 and isinstance(self.imgaug_transform[0], iaa.Resize): + if len(self.imgaug_transform) == 1 and isinstance( + self.imgaug_transform[0], iaa.Resize + ): # we have a deterministic resizing augmentation; use precomputed heatmaps example_dict["heatmaps"] = self.label_heatmaps[idx] else: diff --git a/lightning_pose/data/utils.py b/lightning_pose/data/utils.py index 8344606f..cfaef3d1 100644 --- a/lightning_pose/data/utils.py +++ b/lightning_pose/data/utils.py @@ -1,14 +1,14 @@ """Dataset/data module utilities.""" +from typing import Any, List, Literal, Optional, Tuple, TypedDict, Union + import imgaug.augmenters as iaa -from kornia import image_to_tensor +import lightning.pytorch as pl import numpy as np -from nvidia.dali.plugin.pytorch import DALIGenericIterator import torch +from nvidia.dali.plugin.pytorch import DALIGenericIterator from torchtyping import TensorType from typeguard import typechecked -from typing import List, Literal, Optional, Tuple, Union, Dict, Any, TypedDict -import lightning.pytorch as pl # below are a bunch of classes that streamline data typechecking @@ -43,7 +43,9 @@ class BaseLabeledBatchDict(TypedDict): class HeatmapLabeledBatchDict(BaseLabeledBatchDict): """Batch type for heatmap labeled data.""" - heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width", float] + heatmaps: TensorType[ + "batch", "num_keypoints", "heatmap_height", "heatmap_width", float + ] class UnlabeledBatchDict(TypedDict): @@ -51,7 +53,9 @@ class UnlabeledBatchDict(TypedDict): frames: Union[ TensorType["seq_len", "RGB":3, "image_height", "image_width", float], - TensorType["seq_len", "context":5, "RGB":3, "image_height", "image_width", float], + TensorType[ + "seq_len", "context":5, "RGB":3, "image_height", "image_width", float + ], ] transforms: Union[ TensorType["seq_len", "h":2, "w":3, float], @@ -176,7 +180,9 @@ def dataset_length(self) -> int: name = "%s_dataset" % self.cond return len(getattr(self.data_module, name)) - def get_loader(self) -> Union[torch.utils.data.DataLoader, SemiSupervisedDataLoaderDict]: + def get_loader( + self, + ) -> Union[torch.utils.data.DataLoader, SemiSupervisedDataLoaderDict]: if self.cond == "train": return self.data_module.train_dataloader() if self.cond == "val": @@ -197,8 +203,7 @@ def verify_labeled_loader( return labeled_loader def iterate_over_dataloader( - self, - loader: torch.utils.data.DataLoader + self, loader: torch.utils.data.DataLoader ) -> Tuple[ TensorType["num_examples", Any], Union[ @@ -220,7 +225,10 @@ def iterate_over_dataloader( concat_images = None # assert that indeed the number of columns does not change after concatenation, # and that the number of rows is the dataset length. 
- assert concat_keypoints.shape == (self.dataset_length, keypoints_list[0].shape[1]) + assert concat_keypoints.shape == ( + self.dataset_length, + keypoints_list[0].shape[1], + ) return concat_keypoints, concat_images def __call__( @@ -334,10 +342,7 @@ def compute_num_train_frames( else: if train_frames >= len_train_dataset: # take max number of train frames - print( - f"Warning! Requested training frames exceeds training set size; " - f"using all" - ) + print("Warning! Requested training frames exceeds training set size; using all") n_train_frames = len_train_dataset elif train_frames == 1: # assume this is a fraction; use full dataset @@ -354,7 +359,7 @@ def compute_num_train_frames( return n_train_frames -#@typechecked +# @typechecked def generate_heatmaps( keypoints: TensorType["batch", "num_keypoints", 2], height: int, @@ -395,7 +400,7 @@ def generate_heatmaps( heatmaps = (yy - keypoints[:, :, :, :1]) ** 2 # also flipped order here heatmaps += (xx - keypoints[:, :, :, 1:]) ** 2 # also flipped order here heatmaps *= -1 - heatmaps /= 2 * sigma ** 2 + heatmaps /= 2 * sigma**2 heatmaps = torch.exp(heatmaps) # normalize all heatmaps to one heatmaps = heatmaps / torch.sum(heatmaps, dim=(2, 3), keepdim=True) @@ -403,7 +408,8 @@ def generate_heatmaps( # (all zeros heatmaps are ignored in the supervised heatmap loss) if uniform_heatmaps: filler_heatmap = torch.ones( - (out_height, out_width), device=keypoints.device) / (out_height * out_width) + (out_height, out_width), device=keypoints.device + ) / (out_height * out_width) else: filler_heatmap = torch.zeros((out_height, out_width), device=keypoints.device) @@ -411,7 +417,7 @@ def generate_heatmaps( return heatmaps -#@typechecked +# @typechecked def evaluate_heatmaps_at_location( heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], locs: TensorType["batch", "num_keypoints", 2], @@ -424,38 +430,42 @@ def evaluate_heatmaps_at_location( taking all pixels within two standard deviations of the predicted pixel.""" pix_to_consider = int(np.floor(sigma * num_stds)) # get all pixels within num_stds. 
num_pad = pix_to_consider - heatmaps_padded = torch.zeros(( - heatmaps.shape[0], - heatmaps.shape[1], - heatmaps.shape[2] + num_pad * 2, - heatmaps.shape[3] + num_pad * 2, + heatmaps_padded = torch.zeros( + ( + heatmaps.shape[0], + heatmaps.shape[1], + heatmaps.shape[2] + num_pad * 2, + heatmaps.shape[3] + num_pad * 2, ), device=heatmaps.device, ) heatmaps_padded[:, :, num_pad:-num_pad, num_pad:-num_pad] = heatmaps - i = torch.arange(heatmaps_padded.shape[0], device=heatmaps_padded.device).reshape(-1, 1, 1, 1) - j = torch.arange(heatmaps_padded.shape[1], device=heatmaps_padded.device).reshape(1, -1, 1, 1) + i = torch.arange(heatmaps_padded.shape[0], device=heatmaps_padded.device).reshape( + -1, 1, 1, 1 + ) + j = torch.arange(heatmaps_padded.shape[1], device=heatmaps_padded.device).reshape( + 1, -1, 1, 1 + ) k = locs[:, :, None, 1, None].type(torch.int64) + num_pad - l = locs[:, :, 0, None, None].type(torch.int64) + num_pad + m = locs[:, :, 0, None, None].type(torch.int64) + num_pad offsets = list(np.arange(-pix_to_consider, pix_to_consider + 1)) vals_all = [] for offset in offsets: k_offset = k + offset for offset_2 in offsets: - l_offset = l + offset_2 + m_offset = m + offset_2 # get rid of singleton dims - vals = heatmaps_padded[i, j, k_offset, l_offset].squeeze(-1).squeeze(-1) + vals = heatmaps_padded[i, j, k_offset, m_offset].squeeze(-1).squeeze(-1) vals_all.append(vals) vals = torch.stack(vals_all, 0).sum(0) return vals -#@typechecked +# @typechecked def undo_affine_transform( keypoints: TensorType["seq_len", "num_keypoints", 2], transform: Union[TensorType["seq_len", 2, 3], TensorType[2, 3]], ) -> TensorType["seq_len", "num_keypoints", 2]: - # add 1s to get keypoints in projective geometry coords ones = torch.ones( (keypoints.shape[0], keypoints.shape[1], 1), @@ -474,18 +484,24 @@ def undo_affine_transform( mats_inv_torch = [] for idx in range(mat.shape[0]): mat_inv_ = torch.linalg.inv(mat[idx, :, :2]) - mat_inv = torch.concat([mat_inv_, torch.matmul(-mat_inv_, mat[idx, :, -1, None])], dim=1) - mats_inv_torch.append(torch.tensor( - torch.transpose(mat_inv, 1, 0), - dtype=keypoints.dtype, - device=keypoints.device, - requires_grad=True, - )) + mat_inv = torch.concat( + [mat_inv_, torch.matmul(-mat_inv_, mat[idx, :, -1, None])], dim=1 + ) + mats_inv_torch.append( + torch.tensor( + torch.transpose(mat_inv, 1, 0), + dtype=keypoints.dtype, + device=keypoints.device, + requires_grad=True, + ) + ) # make a single block of inverse matrices if len(mats_inv_torch) == 1: # replicate this inverse matrix for each element of the batch - mat_inv_torch = torch.tile(mats_inv_torch[0].unsqueeze(0), dims=(keypoints.shape[0], 1, 1)) + mat_inv_torch = torch.tile( + mats_inv_torch[0].unsqueeze(0), dims=(keypoints.shape[0], 1, 1) + ) else: # different transformation for each element of the batch mat_inv_torch = torch.stack(mats_inv_torch, dim=0) diff --git a/lightning_pose/losses/factory.py b/lightning_pose/losses/factory.py index 82474657..df4738bc 100644 --- a/lightning_pose/losses/factory.py +++ b/lightning_pose/losses/factory.py @@ -1,9 +1,10 @@ """High-level loss class that orchestrates the individual losses.""" +from typing import Dict, List, Literal, Optional, Tuple, Union + import lightning.pytorch as pl import torch from torchtyping import TensorType -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule from lightning_pose.losses.losses import get_loss_classes diff --git 
a/lightning_pose/losses/helpers.py b/lightning_pose/losses/helpers.py index 15e4ede9..c8c261ff 100644 --- a/lightning_pose/losses/helpers.py +++ b/lightning_pose/losses/helpers.py @@ -1,9 +1,10 @@ """Helper functions for losses.""" -import torch +from typing import Dict, Literal, Union + import numpy as np +import torch from typeguard import typechecked -from typing import Tuple, Union, Dict, List, Literal class EmpiricalEpsilon: diff --git a/lightning_pose/losses/losses.py b/lightning_pose/losses/losses.py index d672f117..b77ef247 100644 --- a/lightning_pose/losses/losses.py +++ b/lightning_pose/losses/losses.py @@ -19,27 +19,25 @@ """ -from kornia.losses import js_div_loss_2d, kl_div_loss_2d -from omegaconf import ListConfig +import warnings +from typing import Dict, List, Literal, Optional, Tuple, Type, Union + import lightning.pytorch as pl import torch +from kornia.losses import js_div_loss_2d, kl_div_loss_2d +from omegaconf import ListConfig from torch.nn import functional as F from torchtyping import TensorType from typeguard import typechecked -from typing import Any, Callable, Dict, Tuple, List, Literal, Optional, Union, Type -import warnings from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule from lightning_pose.data.utils import generate_heatmaps -from lightning_pose.utils.pca import ( - KeypointPCA, - format_multiview_data_for_pca, -) +from lightning_pose.utils.pca import KeypointPCA _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -#@typechecked +# @typechecked class Loss(pl.LightningModule): """Parent class for all losses.""" @@ -121,7 +119,7 @@ def __call__(self, *args, **kwargs): raise NotImplementedError -#@typechecked +# @typechecked class HeatmapLoss(Loss): """Parent class for different heatmap losses (MSE, Wasserstein, etc).""" @@ -172,7 +170,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class HeatmapMSELoss(HeatmapLoss): """MSE loss between heatmaps.""" @@ -197,7 +195,7 @@ def compute_loss( return loss -#@typechecked +# @typechecked class HeatmapKLLoss(HeatmapLoss): """Kullback-Leibler loss between heatmaps.""" @@ -224,7 +222,7 @@ def compute_loss( return loss[0] -#@typechecked +# @typechecked class HeatmapJSLoss(HeatmapLoss): """Kullback-Leibler loss between heatmaps.""" @@ -251,7 +249,7 @@ def compute_loss( return loss[0] -#@typechecked +# @typechecked class PCALoss(Loss): """Penalize predictions that fall outside a low-dimensional subspace.""" @@ -343,7 +341,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class TemporalLoss(Loss): """Penalize temporal differences for each target. @@ -432,7 +430,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class TemporalHeatmapLoss(Loss): """Penalize temporal differences for each heatmap. @@ -536,7 +534,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class UnimodalLoss(Loss): """Encourage heatmaps to be unimodal using various measures.""" @@ -586,13 +584,13 @@ def remove_nans( TensorType["num_valid_keypoints", "heatmap_height", "heatmap_width"], ]: """Remove nans from targets and predictions. 
- Args: + Args: targets: (batch, num_keypoints, heatmap_height, heatmap_width) predictions: (batch, num_keypoints, heatmap_height, heatmap_width) confidences: (batch, num_keypoints) Returns: - clean targets: (num_valid_keypoints, heatmap_height, heatmap_width), concatenated across different images and keypoints - clean predictions: (num_valid_keypoints, heatmap_height, heatmap_width), concatenated across different images and keypoints + clean targets: concatenated across different images and keypoints + clean predictions: concatenated across different images and keypoints """ # use confidences to get rid of unsupervised targets with likely occlusions idxs_ignore = confidences < self.prob_threshold @@ -632,8 +630,10 @@ def __call__( ) -> Tuple[TensorType[()], List[dict]]: """Compute unimodal loss. Args: - keypoints_pred_augmented: (batch, 2 * num_keypoints) these are in the augmented image space - heatmaps_pred: (batch, num_keypoints, heatmap_height, heatmap_width) these are also in the augmented space, matching the keypoints_pred_augmented""" + keypoints_pred_augmented: these are in the augmented image space + heatmaps_pred: these are also in the augmented space, matching the + keypoints_pred_augmented + """ # turn keypoint predictions into unimodal heatmaps keypoints_pred = keypoints_pred_augmented.reshape(keypoints_pred_augmented.shape[0], -1, 2) @@ -659,7 +659,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class RegressionMSELoss(Loss): """MSE loss between ground truth and predicted coordinates.""" @@ -711,7 +711,7 @@ def __call__( return self.weight * scalar_loss, logs -#@typechecked +# @typechecked class RegressionRMSELoss(RegressionMSELoss): """Root MSE loss between ground truth and predicted coordinates.""" diff --git a/lightning_pose/metrics.py b/lightning_pose/metrics.py index 90293493..7d4ef429 100644 --- a/lightning_pose/metrics.py +++ b/lightning_pose/metrics.py @@ -1,7 +1,8 @@ +from typing import Union + import numpy as np -from omegaconf import DictConfig import torch -from typing import Optional, Union +from omegaconf import DictConfig from typeguard import typechecked from lightning_pose.utils.pca import KeypointPCA @@ -110,7 +111,6 @@ def pca_multiview_reprojection_error( shape (samples, n_keypoints) """ - from lightning_pose.utils.pca import format_multiview_data_for_pca if not isinstance(keypoints_pred, torch.Tensor): keypoints_pred = torch.tensor(keypoints_pred, device=pca.device, dtype=torch.float32) diff --git a/lightning_pose/models/backbones/torchvision.py b/lightning_pose/models/backbones/torchvision.py index b7107cbb..6e3b0c3b 100644 --- a/lightning_pose/models/backbones/torchvision.py +++ b/lightning_pose/models/backbones/torchvision.py @@ -1,4 +1,5 @@ from collections import OrderedDict + import torch import torchvision.models as tvmodels from typeguard import typechecked @@ -29,7 +30,7 @@ def build_backbone( # load resnet50 pretrained using SimCLR on imagenet from pl_bolts.models.self_supervised import SimCLR - ckpt_url = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt" + ckpt_url = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt" # noqa: E501 simclr = SimCLR.load_from_checkpoint(ckpt_url, strict=False) base = simclr.encoder @@ -37,9 +38,9 @@ def build_backbone( base = getattr(tvmodels, "resnet50")(weights=None) backbone_type = "_".join(backbone_arch.split("_")[2:]) if backbone_type == "apose": - ckpt_url = 
"https://download.openmmlab.com/mmpose/animal/resnet/res50_animalpose_256x256-e1f30bff_20210426.pth" + ckpt_url = "https://download.openmmlab.com/mmpose/animal/resnet/res50_animalpose_256x256-e1f30bff_20210426.pth" # noqa: E501 else: - ckpt_url = "https://download.openmmlab.com/mmpose/animal/resnet/res50_ap10k_256x256-35760eb8_20211029.pth" + ckpt_url = "https://download.openmmlab.com/mmpose/animal/resnet/res50_ap10k_256x256-35760eb8_20211029.pth" # noqa: E501 state_dict = torch.hub.load_state_dict_from_url(ckpt_url)["state_dict"] new_state_dict = OrderedDict() @@ -53,11 +54,11 @@ def build_backbone( base = getattr(tvmodels, "resnet50")(weights=None) backbone_type = "_".join(backbone_arch.split("_")[2:]) if backbone_type == "jhmdb": - ckpt_url = "https://download.openmmlab.com/mmpose/top_down/resnet/res50_jhmdb_sub3_256x256-c4ec1a0b_20201122.pth" + ckpt_url = "https://download.openmmlab.com/mmpose/top_down/resnet/res50_jhmdb_sub3_256x256-c4ec1a0b_20201122.pth" # noqa: E501 elif backbone_type == "res_rle": - ckpt_url = "https://download.openmmlab.com/mmpose/top_down/deeppose/deeppose_res50_mpii_256x256_rle-5f92a619_20220504.pth" + ckpt_url = "https://download.openmmlab.com/mmpose/top_down/deeppose/deeppose_res50_mpii_256x256_rle-5f92a619_20220504.pth" # noqa: E501 elif backbone_type == "top_res": - ckpt_url = "https://download.openmmlab.com/mmpose/top_down/resnet/res50_mpii_256x256-418ffc88_20200812.pth" + ckpt_url = "https://download.openmmlab.com/mmpose/top_down/resnet/res50_mpii_256x256-418ffc88_20200812.pth" # noqa: E501 state_dict = torch.hub.load_state_dict_from_url(ckpt_url)["state_dict"] new_state_dict = OrderedDict() diff --git a/lightning_pose/models/backbones/vits.py b/lightning_pose/models/backbones/vits.py index 8dd8b5c6..006d36f1 100644 --- a/lightning_pose/models/backbones/vits.py +++ b/lightning_pose/models/backbones/vits.py @@ -1,6 +1,7 @@ from functools import partial -from segment_anything.modeling import ImageEncoderViT + import torch +from segment_anything.modeling import ImageEncoderViT from typeguard import typechecked diff --git a/lightning_pose/models/base.py b/lightning_pose/models/base.py index b5ee4076..cf12f89d 100644 --- a/lightning_pose/models/base.py +++ b/lightning_pose/models/base.py @@ -1,20 +1,21 @@ """Base class for backbone that acts as a feature extractor.""" -from omegaconf import DictConfig -from lightning.pytorch import LightningModule +from typing import Dict, Literal, Optional, Union + import torch +from lightning.pytorch import LightningModule +from omegaconf import DictConfig from torch.optim import Adam from torch.optim.lr_scheduler import MultiStepLR from torchtyping import TensorType from typeguard import typechecked -from typing import Dict, Literal, Optional, Union from lightning_pose.data.utils import ( BaseLabeledBatchDict, HeatmapLabeledBatchDict, - UnlabeledBatchDict, SemiSupervisedBatchDict, SemiSupervisedHeatmapBatchDict, + UnlabeledBatchDict, ) from lightning_pose.models import ALLOWED_BACKBONES diff --git a/lightning_pose/models/heatmap_tracker.py b/lightning_pose/models/heatmap_tracker.py index aefd3e6c..1c7207b9 100644 --- a/lightning_pose/models/heatmap_tracker.py +++ b/lightning_pose/models/heatmap_tracker.py @@ -1,21 +1,17 @@ """Models that produce heatmaps of keypoints from images.""" +from typing import Optional, Tuple, Union + +import torch from kornia.filters import filter2d -from kornia.geometry.subpix import spatial_softmax2d, spatial_expectation2d +from kornia.geometry.subpix import spatial_expectation2d, 
spatial_softmax2d from kornia.geometry.transform.pyramid import _get_pyramid_gaussian_kernel -import numpy as np from omegaconf import DictConfig -import torch from torch import nn from torchtyping import TensorType -from typeguard import typechecked -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union from typing_extensions import Literal -from torch.optim import Adam -from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau from lightning_pose.data.utils import ( - BaseLabeledBatchDict, HeatmapLabeledBatchDict, UnlabeledBatchDict, evaluate_heatmaps_at_location, @@ -29,7 +25,7 @@ def upsample( inputs: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], -) -> TensorType["batch", "num_keypoints", "2 x heatmap_height", "2 x heatmap_width"]: +) -> TensorType["batch", "num_keypoints", "two_x_heatmap_height", "two_x_heatmap_width"]: """Upsample batch of heatmaps using interpolation (no learned weights). This is a copy of kornia's pyrup function but with better defaults. @@ -96,7 +92,7 @@ def __init__( self.num_targets = num_keypoints * 2 self.loss_factory = loss_factory # TODO: downsample_factor may be in mismatch between datamodule and model. - self.downsample_factor = downsample_factor + self.downsample_factor = downsample_factor self.upsampling_layers = self.make_upsampling_layers() self.initialize_upsampling_layers() self.output_shape = output_shape @@ -297,7 +293,7 @@ def forward( ], ) -> TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"]: """Forward pass through the network.""" - # we get one representation for each desired output. + # we get one representation for each desired output. # in the case of unsupervised sequences + context, we have outputs for all images but the # first two and last two. 
# this is all handled internally by get_representations() @@ -425,7 +421,10 @@ def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> dict: # undo augmentation if needed if batch["transforms"].shape[-1] == 3: # reshape to (seq_len, n_keypoints, 2) - pred_kps = torch.reshape(predicted_keypoints_augmented, (predicted_keypoints_augmented.shape[0], -1, 2)) + pred_kps = torch.reshape( + predicted_keypoints_augmented, + (predicted_keypoints_augmented.shape[0], -1, 2) + ) # undo pred_kps = undo_affine_transform(pred_kps, batch["transforms"]) # reshape to (seq_len, n_keypoints * 2) @@ -434,8 +433,8 @@ def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> dict: predicted_keypoints = predicted_keypoints_augmented return { - "heatmaps_pred": predicted_heatmaps, # if augmented, these are the augmented heatmaps - "keypoints_pred": predicted_keypoints, # if we augmented, these are the original keypoints - "keypoints_pred_augmented": predicted_keypoints_augmented, # these keypoints match heatmaps_pred, all are augmented + "heatmaps_pred": predicted_heatmaps, # if augmented, augmented heatmaps + "keypoints_pred": predicted_keypoints, # if augmented, original keypoints + "keypoints_pred_augmented": predicted_keypoints_augmented, # match heatmaps_pred "confidences": confidence, } diff --git a/lightning_pose/models/heatmap_tracker_mhcrnn.py b/lightning_pose/models/heatmap_tracker_mhcrnn.py index bc7137ff..30e453d9 100644 --- a/lightning_pose/models/heatmap_tracker_mhcrnn.py +++ b/lightning_pose/models/heatmap_tracker_mhcrnn.py @@ -1,18 +1,19 @@ """Models that produce heatmaps of keypoints from images.""" +from typing import Optional, Tuple, Union + +import torch from kornia.geometry.subpix import spatial_softmax2d from omegaconf import DictConfig -import torch from torch import nn from torchtyping import TensorType from typeguard import typechecked -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union from typing_extensions import Literal from lightning_pose.data.utils import ( - undo_affine_transform, HeatmapLabeledBatchDict, UnlabeledBatchDict, + undo_affine_transform, ) from lightning_pose.losses.factory import LossFactory from lightning_pose.models import ALLOWED_BACKBONES @@ -93,8 +94,8 @@ def heatmaps_from_representations( self, representations: TensorType["batch", "features", "rep_height", "rep_width", "frames"], ) -> Tuple[ - TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], - TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], + TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], + TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"], ]: """Handle context frames then upsample to get final heatmaps.""" # permute to shape (frames, batch, features, rep_height, rep_width) @@ -111,8 +112,8 @@ def forward( TensorType["batch", "frames", "channels":3, "image_height", "image_width"] ], ) -> Tuple[ - TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"], - TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"], + TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"], + TensorType["num_valid_outputs", "num_keypoints", "heatmap_height", "heatmap_width"], ]: """Forward pass through the network.""" diff --git a/lightning_pose/models/regression_tracker.py b/lightning_pose/models/regression_tracker.py index adc072d0..77f26579 100644 --- a/lightning_pose/models/regression_tracker.py +++ 
b/lightning_pose/models/regression_tracker.py @@ -1,18 +1,17 @@ """Models that produce (x, y) coordinates of keypoints from images.""" -from omegaconf import DictConfig +from typing import Optional, Tuple, Union + import torch +from omegaconf import DictConfig from torch import nn from torchtyping import TensorType from typeguard import typechecked -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from typing_extensions import Literal from lightning_pose.data.utils import ( - evaluate_heatmaps_at_location, - undo_affine_transform, BaseLabeledBatchDict, UnlabeledBatchDict, + undo_affine_transform, ) from lightning_pose.losses.factory import LossFactory from lightning_pose.losses.losses import RegressionRMSELoss diff --git a/lightning_pose/utils/__init__.py b/lightning_pose/utils/__init__.py index 2bd0648f..bcaff385 100644 --- a/lightning_pose/utils/__init__.py +++ b/lightning_pose/utils/__init__.py @@ -1,5 +1,5 @@ -from omegaconf import ListConfig import torch +from omegaconf import ListConfig _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/lightning_pose/utils/fiftyone.py b/lightning_pose/utils/fiftyone.py index 136c7dcd..d21e4f39 100644 --- a/lightning_pose/utils/fiftyone.py +++ b/lightning_pose/utils/fiftyone.py @@ -1,15 +1,15 @@ +import os +from typing import Dict, List, Literal, Optional, Union + import fiftyone as fo -from tqdm import tqdm -from typing import Dict, List, Optional, Union, Callable, Any, Literal -import pandas as pd import numpy as np -from omegaconf import DictConfig, OmegaConf, ListConfig -import os +import pandas as pd +from omegaconf import DictConfig +from tqdm import tqdm from typeguard import typechecked -from lightning_pose.utils.io import get_videos_in_dir -from lightning_pose.utils.io import return_absolute_path, return_absolute_data_paths from lightning_pose.utils import pretty_print_str +from lightning_pose.utils.io import return_absolute_data_paths, return_absolute_path @typechecked @@ -294,7 +294,8 @@ def create_dataset(self) -> fo.Dataset: samples = [] # read each model's csv into a pandas dataframe self.load_model_predictions() - # assumes that train,test,val split is identical for all the different models. may be different with ensembling. 
+ # assumes that train,test,val split is identical for all the different models + # may be different with ensembling self.data_tags = get_image_tags(self.preds_pandas_df_dict[self.model_names[0]]) # build the ground-truth keypoints per image gt_keypoints_list = self.get_gt_keypoints_list() diff --git a/lightning_pose/utils/io.py b/lightning_pose/utils/io.py index d282d13c..00c63573 100644 --- a/lightning_pose/utils/io.py +++ b/lightning_pose/utils/io.py @@ -1,10 +1,11 @@ """Path handling functions.""" -from omegaconf import DictConfig, ListConfig import os +from typing import List, Optional, Tuple, Union + import pandas as pd +from omegaconf import DictConfig, ListConfig from typeguard import typechecked -from typing import List, Tuple, Union, Optional @typechecked @@ -175,9 +176,13 @@ def get_videos_in_dir(video_dir: str, return_mp4_only: bool = True) -> List[str] assert os.path.isdir(video_dir) # get all video files in directory, from allowed formats allowed_formats = (".mp4", ".avi", ".mov") - if return_mp4_only == True: + if return_mp4_only: allowed_formats = ".mp4" - video_files = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(allowed_formats)] + video_files = [ + os.path.join(video_dir, f) + for f in os.listdir(video_dir) + if f.endswith(allowed_formats) + ] if len(video_files) == 0: raise IOError("Did not find any valid video files in %s" % video_dir) diff --git a/lightning_pose/utils/pca.py b/lightning_pose/utils/pca.py index a1ee7cc9..f66475ce 100644 --- a/lightning_pose/utils/pca.py +++ b/lightning_pose/utils/pca.py @@ -1,16 +1,17 @@ """PCA class to assist with computing PCA losses.""" +import warnings +from typing import Any, Dict, List, Literal, Optional, Union + import numpy as np -from omegaconf import DictConfig, ListConfig -from sklearn.decomposition import PCA import torch +from omegaconf import ListConfig +from sklearn.decomposition import PCA from torchtyping import TensorType from typeguard import typechecked -from typing import List, Optional, Union, Literal, Dict, Any -import warnings from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule -from lightning_pose.data.utils import clean_any_nans, DataExtractor +from lightning_pose.data.utils import DataExtractor, clean_any_nans from lightning_pose.losses.helpers import EmpiricalEpsilon, convert_dict_values_to_tensors _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -139,9 +140,9 @@ def _choose_n_components(self) -> None: self._n_components_kept = 3 if self._n_components_kept != self.components_to_keep: warnings.warn( - "for {} loss, you specified {} components_to_keep, but we will instead keep {} components".format( - self.loss_type, self.components_to_keep, self._n_components_kept - ) + f"for {self.loss_type} loss, you specified {self.components_to_keep} " + f"components_to_keep, but we will instead keep {self._n_components_kept} " + f"components" ) elif self.loss_type == "pca_singleview": if self.pca_object is not None: @@ -263,17 +264,15 @@ def _check_components_to_keep(self) -> None: if type(self.components_to_keep) is int: if self.components_to_keep > self.fitted_pca_object.n_components_: raise ValueError( - "components_to_keep was set to {}, exceeding the maximum value of {} observation dims".format( - self.components_to_keep, self.fitted_pca_object.n_components_ - ) + f"components_to_keep was set to {self.components_to_keep}, exceeding the " + f"maximum value of {self.fitted_pca_object.n_components_} observation dims" ) # if float, ensure a 
proportion between 0.0-1.0 elif type(self.components_to_keep) is float: if self.components_to_keep < 0.0 or self.components_to_keep > 1.0: raise ValueError( - "components_to_keep was set to {} while it has to be between 0.0 and 1.0".format( - self.components_to_keep - ) + f"components_to_keep was set to {self.components_to_keep} while it has to be " + f"between 0.0 and 1.0" ) def _find_first_threshold_cross(self) -> int: diff --git a/lightning_pose/utils/predictions.py b/lightning_pose/utils/predictions.py index d808d3f2..6a04bfa6 100644 --- a/lightning_pose/utils/predictions.py +++ b/lightning_pose/utils/predictions.py @@ -1,35 +1,31 @@ """Functions for predicting keypoints on labeled datasets and unlabeled videos.""" +import os +import time +from typing import List, Optional, Tuple, Union + +import lightning.pytorch as pl import matplotlib.pyplot as plt import numpy as np -from omegaconf import DictConfig, OmegaConf -import os import pandas as pd -import lightning.pytorch as pl +import torch +from omegaconf import DictConfig, OmegaConf from pytorch_lightning import LightningModule -from pytorch_lightning import LightningDataModule from skimage.draw import disk -import time -import torch from torchtyping import TensorType from tqdm import tqdm from typeguard import typechecked -from typing import Dict, List, Literal, Optional, Tuple, Type, Union from lightning_pose.data.dali import LitDaliWrapper, PrepareDALI from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule from lightning_pose.data.utils import count_frames -from lightning_pose.models.heatmap_tracker import ( - HeatmapTracker, - SemiSupervisedHeatmapTracker, -) +from lightning_pose.models.heatmap_tracker import HeatmapTracker, SemiSupervisedHeatmapTracker from lightning_pose.models.regression_tracker import ( RegressionTracker, SemiSupervisedRegressionTracker, ) from lightning_pose.utils import pretty_print_str - _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -117,7 +113,9 @@ def unpack_preds( ]: """unpack list of preds coming out from pl.trainer.predict, confs tuples into tensors. It still returns unnecessary final rows, which should be discarded at the dataframe stage. - This works for the output of predict_loader, suitable for batch_size=1, sequence_length=16, step=16""" + This works for the output of predict_loader, suitable for + batch_size=1, sequence_length=16, step=16 + """ # stack the predictions into rows. # loop over the batches, and stack stacked_preds = torch.vstack([pred[0] for pred in preds]) @@ -172,7 +170,9 @@ def fix_context_preds_confs( else: # we don't have as many predictions as frames; pad with final entry which is valid. 
n_pad = self.frame_count - preds_combined.shape[0] - preds_combined = torch.vstack([preds_combined, torch.tile(preds_combined[0], (n_pad, 1))]) + preds_combined = torch.vstack( + [preds_combined, torch.tile(preds_combined[0], (n_pad, 1))] + ) if zero_pad_confidence: # zeroing out those first and last two rows (after we've shifted everything above) @@ -578,11 +578,15 @@ def get_model_class(map_type: str, semi_supervised: bool) -> LightningModule: ) else: if map_type == "regression": - from lightning_pose.models.regression_tracker import SemiSupervisedRegressionTracker as Model + from lightning_pose.models.regression_tracker import ( + SemiSupervisedRegressionTracker as Model, + ) elif map_type == "heatmap": from lightning_pose.models.heatmap_tracker import SemiSupervisedHeatmapTracker as Model elif map_type == "heatmap_mhcrnn": - from lightning_pose.models.heatmap_tracker_mhcrnn import SemiSupervisedHeatmapTrackerMHCRNN as Model + from lightning_pose.models.heatmap_tracker_mhcrnn import ( + SemiSupervisedHeatmapTrackerMHCRNN as Model, + ) else: raise NotImplementedError( "%s is an invalid model_type for a semi-supervised model" % map_type @@ -616,10 +620,7 @@ def load_model_from_checkpoint( model as a Lightning Module """ - from lightning_pose.utils.io import ( - check_if_semi_supervised, - return_absolute_data_paths, - ) + from lightning_pose.utils.io import check_if_semi_supervised, return_absolute_data_paths from lightning_pose.utils.scripts import ( get_data_module, get_dataset, diff --git a/lightning_pose/utils/scripts.py b/lightning_pose/utils/scripts.py index 1bb07ee1..b544b869 100644 --- a/lightning_pose/utils/scripts.py +++ b/lightning_pose/utils/scripts.py @@ -1,52 +1,43 @@ """Helper functions to build pipeline components from config dictionary.""" +import os +from typing import Dict, Optional, Union + import imgaug.augmenters as iaa -from moviepy.editor import VideoFileClip +import lightning.pytorch as pl import numpy as np -from omegaconf import DictConfig, OmegaConf -import os import pandas as pd -import lightning.pytorch as pl +from moviepy.editor import VideoFileClip +from omegaconf import DictConfig, OmegaConf from typeguard import typechecked -from typing import Dict, Optional, Union from lightning_pose.callbacks import AnnealWeight -from lightning_pose.data.dali import PrepareDALI from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule from lightning_pose.data.datasets import BaseTrackingDataset, HeatmapDataset from lightning_pose.data.utils import compute_num_train_frames, split_sizes_from_probabilities from lightning_pose.losses.factory import LossFactory from lightning_pose.metrics import ( + pca_multiview_reprojection_error, + pca_singleview_reprojection_error, pixel_error, temporal_norm, - pca_singleview_reprojection_error, - pca_multiview_reprojection_error, -) -from lightning_pose.models.regression_tracker import ( - RegressionTracker, - SemiSupervisedRegressionTracker, -) -from lightning_pose.models.heatmap_tracker import ( - HeatmapTracker, - SemiSupervisedHeatmapTracker, ) +from lightning_pose.models.heatmap_tracker import HeatmapTracker, SemiSupervisedHeatmapTracker from lightning_pose.models.heatmap_tracker_mhcrnn import ( HeatmapTrackerMHCRNN, SemiSupervisedHeatmapTrackerMHCRNN, ) +from lightning_pose.models.regression_tracker import ( + RegressionTracker, + SemiSupervisedRegressionTracker, +) from lightning_pose.utils.io import ( check_if_semi_supervised, get_keypoint_names, return_absolute_path, - return_absolute_data_paths, ) from 
lightning_pose.utils.pca import KeypointPCA -from lightning_pose.utils.predictions import ( - load_model_from_checkpoint, - create_labeled_video, - PredictionHandler, - predict_single_video, -) +from lightning_pose.utils.predictions import create_labeled_video, predict_single_video @typechecked @@ -195,9 +186,10 @@ def get_dataset( cfg: DictConfig, data_dir: str, imgaug_transform: iaa.Sequential ) -> Union[BaseTrackingDataset, HeatmapDataset]: """Create a dataset that contains labeled data.""" - from PIL import Image import os + from PIL import Image + if cfg.model.model_type == "regression": dataset = BaseTrackingDataset( root_directory=data_dir, @@ -224,15 +216,10 @@ def get_dataset( cfg.data.image_orig_dims.height, ): raise ValueError( - f"image_orig_dims in data configuration file is (width=%i, height=%i) but" - f" your image size is (width=%i, height=%i). Please update the data " - f"configuration file" - % ( - cfg.data.image_orig_dims.width, - cfg.data.image_orig_dims.height, - image.size[0], - image.size[1], - ) + f"image_orig_dims in data configuration file is " + f"(width={cfg.data.image_orig_dims.width}, height={cfg.data.image_orig_dims.height}) " + f"but your image size is (width={image.size[0]}, height={image.size[1]}). " + f"Please update the data configuration file" ) return dataset @@ -312,8 +299,8 @@ def get_loss_factories( if loss_name[:8] == "unimodal" or loss_name[:15] == "temporal_heatmap": if cfg.model.model_type == "regression": raise NotImplementedError( - f"unimodal loss can only be used with classes inheriting from " - f"HeatmapTracker. \nYou specified a RegressionTracker model." + "unimodal loss can only be used with classes inheriting from " + "HeatmapTracker. \nYou specified a RegressionTracker model." ) # record original image dims (after initial resizing) height_og = cfg.data.image_resize_dims.height @@ -492,7 +479,9 @@ def get_callbacks( lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch") callbacks.append(lr_monitor) if ckpt_model: - ckpt_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor="val_supervised_loss") + ckpt_callback = pl.callbacks.model_checkpoint.ModelCheckpoint( + monitor="val_supervised_loss" + ) callbacks.append(ckpt_callback) if backbone_unfreeze: transfer_unfreeze_callback = pl.callbacks.BackboneFinetuning( diff --git a/lightning_pose/utils/tests.py b/lightning_pose/utils/tests.py index 5424d5b9..7d4aeb8d 100644 --- a/lightning_pose/utils/tests.py +++ b/lightning_pose/utils/tests.py @@ -1,4 +1,5 @@ import gc + import torch from lightning_pose.utils.scripts import get_loss_factories, get_model diff --git a/scripts/converters/dlc2lp.py b/scripts/converters/dlc2lp.py index 430fffcb..503e6d25 100644 --- a/scripts/converters/dlc2lp.py +++ b/scripts/converters/dlc2lp.py @@ -1,10 +1,10 @@ import argparse import glob -import numpy as np import os -import pandas as pd import shutil +import numpy as np +import pandas as pd parser = argparse.ArgumentParser() parser.add_argument("--dlc_dir", type=str) diff --git a/scripts/create_fiftyone_dataset.py b/scripts/create_fiftyone_dataset.py index e0f5f1e0..acfb20ef 100755 --- a/scripts/create_fiftyone_dataset.py +++ b/scripts/create_fiftyone_dataset.py @@ -1,16 +1,16 @@ """Visualize predictions of models in a fiftyone dashboard.""" +import fiftyone as fo import hydra from omegaconf import DictConfig + +from lightning_pose.utils import pretty_print_str from lightning_pose.utils.fiftyone import ( + FiftyOneFactory, FiftyOneImagePlotter, FiftyOneKeypointVideoPlotter, 
check_dataset,
-    FiftyOneFactory,
 )
-import fiftyone as fo
-
-from lightning_pose.utils import pretty_print_str
 
 
 @hydra.main(config_path="configs", config_name="config_mirror-mouse-example")
diff --git a/scripts/predict_new_vids.py b/scripts/predict_new_vids.py
index 28d3bc2e..edfc0530 100755
--- a/scripts/predict_new_vids.py
+++ b/scripts/predict_new_vids.py
@@ -1,11 +1,12 @@
 """Run inference on a list of models and videos."""
 
+import os
+
 import hydra
-from moviepy.editor import VideoFileClip
+import lightning.pytorch as pl
 import numpy as np
+from moviepy.editor import VideoFileClip
 from omegaconf import DictConfig, OmegaConf
-import os
-import lightning.pytorch as pl
 from typeguard import typechecked
 
 from lightning_pose.utils import get_gpu_list_from_cfg
@@ -13,12 +14,16 @@
     check_if_semi_supervised,
     ckpt_path_from_base_path,
     get_videos_in_dir,
-    return_absolute_path,
     return_absolute_data_paths,
+    return_absolute_path,
 )
 from lightning_pose.utils.predictions import load_model_from_checkpoint
-from lightning_pose.utils.scripts import get_imgaug_transform, get_dataset, get_data_module
-from lightning_pose.utils.scripts import export_predictions_and_labeled_video
+from lightning_pose.utils.scripts import (
+    export_predictions_and_labeled_video,
+    get_data_module,
+    get_dataset,
+    get_imgaug_transform,
+)
 
 """
 this script will get two important args: the model to use and the video folder to process.
 hydra will orchestrate both. advantages -- in the future we could parallelize to new machines.
diff --git a/scripts/train_hydra.py b/scripts/train_hydra.py
index 0afdc6c7..a44b9bbf 100755
--- a/scripts/train_hydra.py
+++ b/scripts/train_hydra.py
@@ -1,11 +1,12 @@
 """Example model training script."""
 
-import hydra
-from omegaconf import DictConfig
 import os
+
+import hydra
 import lightning.pytorch as pl
+from omegaconf import DictConfig
 
-from lightning_pose.utils import pretty_print_str, pretty_print_cfg
+from lightning_pose.utils import pretty_print_cfg, pretty_print_str
 from lightning_pose.utils.io import (
     check_video_paths,
     return_absolute_data_paths,
@@ -13,15 +14,15 @@
 )
 from lightning_pose.utils.predictions import predict_dataset
 from lightning_pose.utils.scripts import (
+    calculate_train_batches,
+    compute_metrics,
     export_predictions_and_labeled_video,
+    get_callbacks,
     get_data_module,
     get_dataset,
     get_imgaug_transform,
     get_loss_factories,
     get_model,
-    get_callbacks,
-    calculate_train_batches,
-    compute_metrics,
 )
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..349a5c34
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,16 @@
+[flake8]
+max-line-length = 99
+ignore = F821, W503
+extend-ignore = E203
+exclude =
+    .git,
+    __pycache__,
+    __init__.py,
+    build,
+    dist,
+    docs/,
+    scripts/
+
+[isort]
+line_length = 99
+profile = black
diff --git a/setup.py b/setup.py
index 48d0f676..739a899a 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,8 @@
-#:!/usr/bin/env python
-from setuptools import find_packages, setup
-import subprocess
+
 import re
+import subprocess
+
+from setuptools import find_packages, setup
 
 VERSION = "0.0.2"  # was previously None
diff --git a/tests/conftest.py b/tests/conftest.py
index 71c3aef3..eed74139 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,25 +6,26 @@
 """
 
 import copy
-import imgaug.augmenters as iaa
-from omegaconf import ListConfig, OmegaConf
 import os
-import pytest
-import lightning.pytorch as pl
 import shutil
+from typing import Callable, List
+
+import imgaug.augmenters as iaa
+import lightning.pytorch as pl
+import pytest
 import torch
-from 
typing import Callable, List, Optional import yaml +from omegaconf import OmegaConf from lightning_pose.data.dali import LitDaliWrapper, PrepareDALI from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule from lightning_pose.data.datasets import BaseTrackingDataset, HeatmapDataset from lightning_pose.utils.io import get_videos_in_dir from lightning_pose.utils.scripts import ( - get_imgaug_transform, - get_dataset, - get_data_module, get_callbacks, + get_data_module, + get_dataset, + get_imgaug_transform, ) TOY_DATA_ROOT_DIR = "data/mirror-mouse-example" diff --git a/tests/data/test_dali.py b/tests/data/test_dali.py index ce100d6c..c2b173f0 100644 --- a/tests/data/test_dali.py +++ b/tests/data/test_dali.py @@ -1,7 +1,5 @@ """Test dali dataloading functionality.""" -from nvidia.dali.plugin.pytorch import LastBatchPolicy -import pytest import torch from lightning_pose.data.dali import video_pipe @@ -36,9 +34,10 @@ def test_video_pipe(video_list): def test_PrepareDALI(cfg, video_list): - from lightning_pose.data.dali import PrepareDALI import os + from lightning_pose.data.dali import PrepareDALI + filenames = video_list assert os.path.isfile(filenames[0]) # base model: check we can build and run pipe and get a decent looking batch diff --git a/tests/data/test_datamodules.py b/tests/data/test_datamodules.py index 515f1cf7..75cd084c 100644 --- a/tests/data/test_datamodules.py +++ b/tests/data/test_datamodules.py @@ -1,9 +1,10 @@ """Test datamodule functionality.""" import pytest +import torch + # from pytorch_lightning.trainer.supporters import CombinedLoader from lightning.pytorch.utilities import CombinedLoader -import torch def test_base_datamodule(cfg, base_data_module): diff --git a/tests/data/test_datasets.py b/tests/data/test_datasets.py index 234c6238..8975a2bf 100644 --- a/tests/data/test_datasets.py +++ b/tests/data/test_datasets.py @@ -1,6 +1,5 @@ """Test basic dataset functionality.""" -import pytest import torch diff --git a/tests/data/test_utils.py b/tests/data/test_utils.py index 21440ace..075cf951 100644 --- a/tests/data/test_utils.py +++ b/tests/data/test_utils.py @@ -1,10 +1,10 @@ """Test data utils functionality.""" import copy -from kornia.geometry.subpix import spatial_softmax2d, spatial_expectation2d -from kornia.geometry.transform import pyrup + import pytest import torch +from kornia.geometry.subpix import spatial_expectation2d, spatial_softmax2d from lightning_pose.data.utils import generate_heatmaps @@ -85,9 +85,10 @@ def test_clean_any_nans(): def test_count_frames(video_list): - from lightning_pose.data.utils import count_frames import cv2 + from lightning_pose.data.utils import count_frames + num_frames = 0 for video_file in video_list: cap = cv2.VideoCapture(video_file) @@ -173,7 +174,7 @@ def test_generate_heatmaps(cfg, heatmap_dataset): def test_generate_uniform_heatmaps(cfg, toy_data_dir): - from lightning_pose.utils.scripts import get_imgaug_transform, get_dataset + from lightning_pose.utils.scripts import get_dataset, get_imgaug_transform # update config cfg_tmp = copy.deepcopy(cfg) @@ -194,7 +195,7 @@ def test_generate_uniform_heatmaps(cfg, toy_data_dir): batch = heatmap_dataset.__getitem__(idx=0) heatmap_gt = batch["heatmaps"].unsqueeze(0) keypts_gt = batch["keypoints"].unsqueeze(0).reshape(1, -1, 2) - + heatmap_uniform_torch = generate_heatmaps( keypts_gt, height=im_height, @@ -230,7 +231,7 @@ def test_generate_uniform_heatmaps(cfg, toy_data_dir): def test_generate_heatmaps_weird_shape(cfg, toy_data_dir): - from 
lightning_pose.utils.scripts import get_imgaug_transform, get_dataset + from lightning_pose.utils.scripts import get_dataset, get_imgaug_transform img_shape = (384, 256) diff --git a/tests/losses/test_helpers.py b/tests/losses/test_helpers.py index ef3c3482..35e48668 100644 --- a/tests/losses/test_helpers.py +++ b/tests/losses/test_helpers.py @@ -1,7 +1,6 @@ """Test loss helper functions.""" import numpy as np -import pytest import torch diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py index d0b9242d..3fc9f4a4 100644 --- a/tests/losses/test_losses.py +++ b/tests/losses/test_losses.py @@ -1,13 +1,11 @@ """Test loss classes.""" -import torch import numpy as np import pytest -import yaml +import torch from lightning_pose.utils.pca import format_multiview_data_for_pca - stage = "train" device = "cpu" @@ -382,6 +380,7 @@ def test_unimodal_mse_loss(): def test_unimodal_kl_loss(): from kornia.geometry.subpix import spatial_softmax2d + from lightning_pose.losses.losses import UnimodalLoss img_size = 48 @@ -426,6 +425,7 @@ def test_unimodal_kl_loss(): def test_unimodal_js_loss(): from kornia.geometry.subpix import spatial_softmax2d + from lightning_pose.losses.losses import UnimodalLoss img_size = 48 diff --git a/tests/models/test_base.py b/tests/models/test_base.py index fda93022..63143fc5 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -1,13 +1,11 @@ """Test functionality of base model classes.""" -import pytest import segment_anything import torch import torchvision from lightning_pose.models.base import BaseFeatureExtractor - _TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BATCH_SIZE = 2 diff --git a/tests/models/test_heatmap_tracker.py b/tests/models/test_heatmap_tracker.py index a41550cf..6dbfa8f9 100644 --- a/tests/models/test_heatmap_tracker.py +++ b/tests/models/test_heatmap_tracker.py @@ -1,7 +1,6 @@ """Test the initialization and training of heatmap models.""" import copy -import pytest from lightning_pose.utils.tests import run_model_test diff --git a/tests/models/test_heatmap_tracker_mhcrnn.py b/tests/models/test_heatmap_tracker_mhcrnn.py index 12a7c2dd..c6b9ef75 100644 --- a/tests/models/test_heatmap_tracker_mhcrnn.py +++ b/tests/models/test_heatmap_tracker_mhcrnn.py @@ -1,7 +1,6 @@ """Test the initialization and training of context heatmap multi-head crnn models.""" import copy -import pytest from lightning_pose.utils.tests import run_model_test diff --git a/tests/models/test_regression_tracker.py b/tests/models/test_regression_tracker.py index f97faf24..b58085af 100644 --- a/tests/models/test_regression_tracker.py +++ b/tests/models/test_regression_tracker.py @@ -1,7 +1,6 @@ """Test the initialization and training of regression models.""" import copy -import pytest from lightning_pose.utils.tests import run_model_test diff --git a/tests/utils/test_pca.py b/tests/utils/test_pca.py index 09b50bc8..6d08398e 100644 --- a/tests/utils/test_pca.py +++ b/tests/utils/test_pca.py @@ -3,10 +3,10 @@ import numpy as np import pytest import torch +from lightning.pytorch.utilities import CombinedLoader from lightning_pose.utils.pca import KeypointPCA -# from pytorch_lightning.trainer.supporters import CombinedLoader -from lightning.pytorch.utilities import CombinedLoader + def check_lists_equal(list_1, list_2): return len(list_1) == len(list_2) and sorted(list_1) == sorted(list_2) @@ -19,7 +19,9 @@ def test_train_loader_iter(base_data_module_combined): dataset_length = len(base_data_module_combined.train_dataset) combined_loader = 
base_data_module_combined.train_dataloader()
-    # the default mode of CombinedLoader changes in Lightning 2.0. we manually take the iterbles inside the combined_loader, and make a new class that cycles only over the labeled dataloader.
+    # the default mode of CombinedLoader changes in Lightning 2.0
+    # we manually take the iterables inside the combined_loader, and make a new class that cycles
+    # only over the labeled dataloader
     combined_loader = CombinedLoader(combined_loader.iterables, mode="min_size")
     image_counter = 0
     for i, batch in enumerate(combined_loader):
@@ -187,7 +189,6 @@ def test_component_chooser():
     # create fake data for PCA
     from sklearn.datasets import load_diabetes
     from sklearn.decomposition import PCA
-    import numpy as np
 
     diabetes = load_diabetes()
     data_for_pca = diabetes.data
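
A few of the patterns this patch touches are easier to see in isolation. First, the `# noqa: E501` markers added to the backbone checkpoint URLs: the new `setup.cfg` sets `max-line-length = 99`, and the URLs cannot be wrapped, so the directive suppresses exactly that one check on exactly that one line. A minimal illustration, reusing one of the checkpoint URLs from the hunks above:

```python
# flake8 (with max-line-length = 99 from setup.cfg) would flag this line as
# E501; the trailing directive silences only E501, and only on this line
ckpt_url = "https://download.openmmlab.com/mmpose/animal/resnet/res50_animalpose_256x256-e1f30bff_20210426.pth"  # noqa: E501
```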
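Those URLs feed `torch.hub.load_state_dict_from_url`, and the hunks show the result being repacked into a fresh `OrderedDict` before loading into a plain torchvision ResNet-50. The diff truncates the repacking loop, so the snippet below is only a sketch of the usual pattern for mmpose checkpoints, whose backbone weights carry a `backbone.` prefix; the key mapping here is an assumption, not the repo's verified logic:

```python
from collections import OrderedDict

import torch
import torchvision.models as tvmodels

ckpt_url = (
    "https://download.openmmlab.com/mmpose/animal/resnet/"
    "res50_ap10k_256x256-35760eb8_20211029.pth"
)
state_dict = torch.hub.load_state_dict_from_url(ckpt_url)["state_dict"]

# assumed mapping: keep backbone weights only, strip the "backbone." prefix
# so the keys line up with torchvision's resnet50
new_state_dict = OrderedDict(
    (k.replace("backbone.", ""), v)
    for k, v in state_dict.items()
    if k.startswith("backbone.")
)
base = tvmodels.resnet50(weights=None)
base.load_state_dict(new_state_dict, strict=False)  # strict=False: fc weights absent
```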
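The `heatmap_tracker.py` hunks reorder the kornia imports that implement the model's differentiable soft-argmax: `spatial_softmax2d` normalizes each heatmap into a probability map, and `spatial_expectation2d` takes its expectation over pixel coordinates. A toy round trip, with shapes and temperature chosen for illustration only:

```python
import torch
from kornia.geometry.subpix import spatial_expectation2d, spatial_softmax2d

# fake network output: (batch, num_keypoints, heatmap_height, heatmap_width)
heatmaps = torch.randn(2, 17, 96, 96)

# higher temperature sharpens the peak before taking the expectation
probs = spatial_softmax2d(heatmaps, temperature=torch.tensor(10.0))

# normalized_coordinates=False returns (x, y) in heatmap pixel coordinates
keypoints = spatial_expectation2d(probs, normalized_coordinates=False)
print(keypoints.shape)  # torch.Size([2, 17, 2])
```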
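In `predictions.py`, `fix_context_preds_confs` pads the stacked predictions when there are fewer predictions than video frames, as the reformatted `torch.vstack`/`torch.tile` call shows. The same pattern in miniature (variable names are illustrative; here the last row stands in for the "final entry which is valid"):

```python
import torch

preds = torch.arange(12.0).reshape(3, 4)  # 3 predicted frames, 4 values each
frame_count = 5                           # the video actually has 5 frames

n_pad = frame_count - preds.shape[0]
# torch.tile repeats the chosen row n_pad times; vstack appends the copies
padded = torch.vstack([preds, torch.tile(preds[-1], (n_pad, 1))])
print(padded.shape)  # torch.Size([5, 4])
```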
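The reworked error messages in `pca.py` enforce the two legal forms of `components_to_keep`: an int bounded by the number of observation dims, or a float between 0.0 and 1.0 read as a proportion of explained variance. sklearn's `PCA`, which the class builds on, uses the same convention for `n_components`:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
data = rng.normal(size=(500, 20))  # 500 samples, 20 observation dims

pca_int = PCA(n_components=3).fit(data)      # keep exactly 3 components
pca_prop = PCA(n_components=0.95).fit(data)  # keep enough for 95% of variance
print(pca_int.n_components_, pca_prop.n_components_)
```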
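Finally, the `test_pca.py` comment cleaned up above refers to Lightning 2.0 changing `CombinedLoader`'s default combination mode; the test rebuilds the loader with `mode="min_size"` so iteration is bounded by the smaller labeled dataloader. A self-contained sketch of that behavior with toy loaders (the exact item yielded per step, batch alone versus a (batch, batch_idx, dataloader_idx) triplet, varies across 2.x releases):

```python
from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

iterables = {
    "labeled": DataLoader(range(4), batch_size=1),      # the short loader
    "unlabeled": DataLoader(range(100), batch_size=1),  # the long loader
}
combined = CombinedLoader(iterables, mode="min_size")

# in min_size mode, iteration stops when the shortest iterable is exhausted
print(len(list(combined)))  # 4, limited by the "labeled" loader
```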