Skip to content

Commit

Permalink
linting + docs (#102)
Browse files Browse the repository at this point in the history
* initial contrib docs

* isort

* flake8

* isort

* fixes
  • Loading branch information
themattinthehatt authored Jul 5, 2023
1 parent 6bfedd0 commit 1daba74
Show file tree
Hide file tree
Showing 46 changed files with 625 additions and 427 deletions.
2 changes: 2 additions & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ where to store the model code,
how to make it visible to users,
how to provide hyperparameters for model construction,
and how to connect it to data loaders and losses

* [General contributing guidelines](contributing.md): pull requests, testing, linting, etc.
43 changes: 43 additions & 0 deletions docs/contributing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Contributing

We welcome community contributions to the Lightning Pose repo!
If you have found a bug or would like to request a minor change, please
[open an issue](https://github.com/danbider/lightning-pose/issues).

In order to contribute code to the repo, please follow the steps below.

You will also need to install the following dev packages:
```bash
pip install flake8 isort
```

### Create a pull request
Please fork the Lightning Pose repo, make your changes, and then
[open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
from your fork. Please read through the rest of this document before submitting the request.

### Linting
Linters automatically find (and sometimes fix) formatting issues in the code. We use two linters,
which are run from the command line in the Lightning Pose repo:

* `flake8`: warns of syntax errors, possible bugs, stylistic errors, etc. Please fix these!
```bash
flake8 .
```

* `isort`: automatically sorts import statements
```bash
isort .
```

### Testing
We currently do not have a continuous integration (CI) setup for the Lightning Pose repo due to its
reliance on GPUs (and the relative expense of CI services that provide GPU machines for testing).
Therefore, it is imperative that you run the unit tests yourself and verify that all tests have
passed before submitting your request (and upon each new push to that request).

To run the tests locally, you must have access to a GPU. Navigate to the Lightning Pose directory
and simply run
```bash
pytest
```
61 changes: 35 additions & 26 deletions lightning_pose/apps/labeled_frame_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@
Refer to apps.md for information on how to use this file.
streamlit run labeled_frame_diagnostics.py -- --model_dir "/home/zeus/content/Pose-app/data/demo/models"
streamlit run labeled_frame_diagnostics.py -- --model_dir ".../Pose-app/data/demo/models"
"""

import argparse
import copy
import numpy as np
import streamlit as st
import seaborn as sns
import pandas as pd
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
from collections import defaultdict
import os
import pandas as pd
import seaborn as sns
import streamlit as st

from lightning_pose.apps.utils import build_precomputed_metrics_df, get_df_box, get_df_scatter
from lightning_pose.apps.utils import update_labeled_file_list
from lightning_pose.apps.utils import get_model_folders, get_model_folders_vis
from lightning_pose.apps.plots import make_seaborn_catplot, make_plotly_scatterplot, get_y_label
from lightning_pose.apps.plots import make_plotly_catplot
from lightning_pose.apps.plots import get_y_label, make_plotly_catplot, make_plotly_scatterplot
from lightning_pose.apps.utils import (
build_precomputed_metrics_df,
get_df_box,
get_df_scatter,
get_model_folders,
get_model_folders_vis,
update_labeled_file_list,
)

# catplot_options = ["boxen", "box", "bar", "violin", "strip"] # for seaborn
catplot_options = ["box", "violin", "strip"] # for plotly
Expand All @@ -31,7 +35,6 @@


def run():

args = parser.parse_args()

st.title("Labeled Frame Diagnostics")
Expand All @@ -48,15 +51,15 @@ def run():
# add a multiselect that shows existing model folders, and allows the user to de-select models
# assume we have args.model_dir and we search two levels down for model folders
model_folders = get_model_folders(args.model_dir)

# get the last two levels of each path to be presented to user
model_folders_vis = get_model_folders_vis(model_folders)

selected_models_vis = st.sidebar.multiselect("Select models", model_folders_vis, default=None)

# append this to full path
selected_models = ["/" + os.path.join(args.model_dir, f) for f in selected_models_vis]

# search for prediction files in the selected model folders
prediction_files = update_labeled_file_list(selected_models, use_ood=args.use_ood)

Expand All @@ -65,7 +68,6 @@ def run():
model_names = copy.copy(selected_models_vis)

if len(prediction_files) > 0: # otherwise don't try to proceed

# ---------------------------------------------------
# load data
# ---------------------------------------------------
Expand Down Expand Up @@ -107,9 +109,11 @@ def run():

# concat dataframes, collapsing hierarchy and making df fatter.
keypoint_names = list(
[c[0] for c in dframes_metrics[new_names[0]]["confidence"].columns[1::3]])
[c[0] for c in dframes_metrics[new_names[0]]["confidence"].columns[1::3]]
)
df_metrics = build_precomputed_metrics_df(
dframes=dframes_metrics, keypoint_names=keypoint_names)
dframes=dframes_metrics, keypoint_names=keypoint_names
)
metric_options = list(df_metrics.keys())

# ---------------------------------------------------'
Expand All @@ -122,7 +126,8 @@ def run():
with col0:
# choose from individual keypoints, their mean, or all at once
keypoint_to_plot = st.selectbox(
"Keypoint:", ["mean", "ALL", *keypoint_names], key="keypoint")
"Keypoint:", ["mean", "ALL", *keypoint_names], key="keypoint"
)

with col1:
# choose which metric to plot
Expand All @@ -142,7 +147,6 @@ def run():
# ---------------------------------------------------

with sup_col00:

st.header("Compare multiple models")

# enumerate plotting options
Expand All @@ -158,7 +162,9 @@ def run():
plot_scale = st.radio("Y-axis scale", scale_options, horizontal=True)

# filter data
df_metrics_filt = df_metrics[metric_to_plot][df_metrics[metric_to_plot].set == data_type]
df_metrics_filt = df_metrics[metric_to_plot][
df_metrics[metric_to_plot].set == data_type
]
n_frames_per_dtype = df_metrics_filt.shape[0] // len(selected_models)

# plot data
Expand All @@ -180,15 +186,19 @@ def run():
else:
top = 0.9
fig_box.fig.subplots_adjust(top=top)
fig_box.fig.suptitle("All keypoints (%i %s frames)" % (n_frames_per_dtype, data_type))
fig_box.fig.suptitle(
f"All keypoints ({n_frames_per_dtype} {data_type} frames)"
)
st.pyplot(fig_box)

else:

st.markdown("###")
fig_box = make_plotly_catplot(
x="model_name", y=keypoint_to_plot,
data=df_metrics_filt[df_metrics_filt[keypoint_to_plot] > int(plot_epsilon)],
x="model_name",
y=keypoint_to_plot,
data=df_metrics_filt[
df_metrics_filt[keypoint_to_plot] > int(plot_epsilon)
],
x_label="Model name",
y_label=y_label,
title=title,
Expand All @@ -208,7 +218,6 @@ def run():
# scatterplots
# ---------------------------------------------------
with sup_col01:

st.header("Compare two models")

col6, col7, col8 = st.columns(3)
Expand Down
Loading

0 comments on commit 1daba74

Please sign in to comment.