Skip to content

Commit

Permalink
Add main function for high level pipeline (#41)
Browse files Browse the repository at this point in the history
* Update base_length_ratio results to scalar and return nan if length=0.

* convert the data in scalar column to the value without [].

* Modify stem width and base-related functions for rice traits.

* change the argument name of 'lateral_only' to 'monocots'.

* change the argument name of 'lateral_only' to 'monocots' (cont.).

* Add function for getting final csv from multiple plants `get_all_plants_traits`

* Add tests for all plant summary

* Modified overwrite parameter to write_per_plant

* Changed plant_name to just the h5 name and not full path

* modify functions to set nan values for base-related traits for rice.

* Modify `get_base_median_ratio` function for all nans.

* Modify `get_traits_value_plant_summary` for traits with no value [].

* Delete monocot test for now

* Fix docstring indentations

* Fix graphpipeline tests

* Add rice test

* Add warning filter for ellipse module

* Add option to change csv suffixes. Import modules in __init__.

* Lint

* Update comment

---------

Co-authored-by: Lin Wang <[email protected]>
Co-authored-by: Talmo Pereira <[email protected]>
  • Loading branch information
3 people authored Jul 5, 2023
1 parent 78e5116 commit f5dab2b
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 42 deletions.
7 changes: 6 additions & 1 deletion sleap_roots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import sleap_roots.convhull
import sleap_roots.ellipse
import sleap_roots.networklength
import sleap_roots.points
import sleap_roots.scanline
import sleap_roots.series
import sleap_roots.summary
import sleap_roots.traitsgraph
import sleap_roots.graphpipeline
from sleap_roots.graphpipeline import get_all_plants_traits
from sleap_roots.series import Series

# Define package version.
# This is read dynamically by setuptools in setup.cfg to determine the release version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
__version__ = "0.0.1"
169 changes: 133 additions & 36 deletions sleap_roots/graphpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
import pandas as pd
import os
from typing import List
from fractions import Fraction
from pathlib import Path
from sleap_roots.traitsgraph import get_traits_graph
from sleap_roots.angle import get_root_angle
from sleap_roots.bases import (
Expand Down Expand Up @@ -47,7 +50,7 @@
get_scanline_first_ind,
get_scanline_last_ind,
)
from sleap_roots.series import Series
from sleap_roots.series import Series, find_all_series
from sleap_roots.summary import get_summary
from sleap_roots.tips import get_tips, get_tip_xs, get_tip_ys
from typing import Dict, Tuple
Expand Down Expand Up @@ -128,6 +131,12 @@
message="invalid value encountered in double_scalars",
category=RuntimeWarning,
)
warnings.filterwarnings(
"ignore",
message="invalid value encountered in scalar divide",
category=RuntimeWarning,
module="ellipse",
)


def get_traits_value_frame(
Expand Down Expand Up @@ -293,27 +302,36 @@ def get_traits_value_plant(
n_line: int = 50,
network_fraction: float = 2 / 3,
write_csv: bool = False,
csv_name: str = "plant_original_traits.csv",
) -> Tuple[Dict, pd.DataFrame]:
"""Get SLEAP traits per plant based on graph.
csv_suffix: str = ".traits.csv",
) -> Tuple[Dict, pd.DataFrame, str]:
"""Get detailed SLEAP traits for every frame of a plant, based on the graph.
Args:
h5: h5 file, plant image series.
monocots: Boolean value, where false is dicot (default), true is rice.
primary_name: primary model name.
lateral_name: lateral model name.
stem_width_tolerance: difference in projection norm between right and left side.
n_line: number of scan lines, np.nan for no interaction.
network_fraction: length found in the lower fration value of the network.
write_csv: Boolean value, where true is write csv file.
csv_name: saved csv file name.
Return:
Tuple of a dictionary and a DataFrame with all traits per plant.
h5: The h5 file representing the plant image series.
monocots: A boolean value indicating whether the plant is a monocot (True)
or a dicot (False) (default).
primary_name: Name of the primary root predictions. The predictions file is
expected to be named `"{h5_path}.{primary_name}.predictions.slp"`.
lateral_name: Name of the lateral root predictions. The predictions file is
expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`.
stem_width_tolerance: The difference in the projection norm between
the right and left side of the stem.
n_line: The number of scan lines. Use np.nan for no interaction.
network_fraction: The length found in the lower fraction value of the network.
write_csv: A boolean value. If True, it writes per plant detailed
CSVs with traits for every instance on every frame.
csv_suffix: If write_csv=True, the CSV file will be saved with the
h5 path + csv_suffix.
Returns:
A tuple containing a dictionary and a DataFrame with all traits per plant,
and the plant name. The Dataframe has root traits per instance and frame
where each row corresponds to a frame in the H5 file. The plant_name is
given by the h5 file.
"""
plant = Series.load(h5, primary_name=primary_name, lateral_name=lateral_name)
plant_name = plant.series_name
# get nymber of frames per plant
# get number of frames per plant
n_frame = len(plant)

data_plant = []
Expand Down Expand Up @@ -384,9 +402,9 @@ def get_traits_value_plant(
)

if write_csv:
csv_name = "plant_original_traits_" + plant_name + ".csv"
csv_name = Path(h5).with_suffix(f"{csv_suffix}")
data_plant_df.to_csv(csv_name, index=False)
return data_plant, data_plant_df
return data_plant, data_plant_df, plant_name


def get_traits_value_plant_summary(
Expand All @@ -398,29 +416,37 @@ def get_traits_value_plant_summary(
n_line: int = 50,
network_fraction: float = 2 / 3,
write_csv: bool = False,
csv_name: str = "plant_original_traits.csv",
csv_suffix: str = ".traits.csv",
write_summary_csv: bool = False,
summary_csv_name: str = "plant_summary_traits.csv",
summary_csv_suffix: str = ".summary_traits.csv",
) -> pd.DataFrame:
"""Get summarized SLEAP traits per plant based on graph.
"""Get summary statistics of SLEAP traits per plant based on graph.
Args:
h5: h5 file, plant image series.
monocots: Boolean value, where false is dicot (default), true is rice.
primary_name: primary model name.
lateral_name: lateral model name.
stem_width_tolerance: difference in projection norm between right and left side.
n_line: number of scan lines, np.nan for no interaction.
network_fraction: length found in the lower fration value of the network.
write_csv: Boolean value, where true is write csv file.
csv_name: saved csv file name.
h5: The h5 file representing the plant image series.
monocots: A boolean value indicating whether the plant is a monocot (True)
or a dicot (False) (default).
primary_name: Name of the primary root predictions. The predictions file is
expected to be named `"{h5_path}.{primary_name}.predictions.slp"`.
lateral_name: Name of the lateral root predictions. The predictions file is
expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`.
stem_width_tolerance: The difference in the projection norm between
the right and left side of the stem.
n_line: The number of scan lines. Use np.nan for no interaction.
network_fraction: The length found in the lower fraction value of the network.
write_csv: A boolean value. If True, it writes per plant detailed
CSVs with traits for every instance on every frame.
csv_suffix: If write_csv=True, the CSV file will be saved with the name
h5 path + csv_suffix.
write_summary_csv: A boolean value. If True, it writes a per plant CSV with the summary statistics.
summary_csv_name: saved summarized csv file name.
summary_csv_suffix: If write_summary_csv=True, the CSV file with the summary
statistics per plant will be saved with the name
h5 path + summary_csv_suffix.
Returns:
A DataFrame with all summarized traits per plant.
A DataFrame with summary statistics of all traits per plant.
"""
data_plant, data_plant_df = get_traits_value_plant(
data_plant, data_plant_df, plant_name = get_traits_value_plant(
h5,
monocots,
primary_name,
Expand All @@ -429,7 +455,7 @@ def get_traits_value_plant_summary(
n_line,
network_fraction,
write_csv,
csv_name,
csv_suffix,
)

# get summarized non-scalar traits per frame
Expand Down Expand Up @@ -602,7 +628,7 @@ def get_traits_value_plant_summary(
data_plant_frame_summary[
data_plant_frame_summary_key[j] + "_prc95"
] = trait_prc95
data_plant_frame_summary["plant_name"] = [os.path.splitext(h5)[0]]
data_plant_frame_summary["plant_name"] = [plant_name]
data_plant_frame_summary_df = pd.DataFrame(data_plant_frame_summary)

# reorganize the column position
Expand All @@ -611,5 +637,76 @@ def get_traits_value_plant_summary(
data_plant_frame_summary_df = data_plant_frame_summary_df[column_names]

if write_summary_csv:
summary_csv_name = Path(h5).with_suffix(f"{summary_csv_suffix}")
data_plant_frame_summary_df.to_csv(summary_csv_name, index=False)
return data_plant_frame_summary_df


def get_all_plants_traits(
    data_folders: List[str],
    primary_name: str,
    lateral_name: str,
    stem_width_tolerance: float = 0.02,
    n_line: int = 50,
    network_fraction: Fraction = Fraction(2, 3),
    write_per_plant_details: bool = False,
    per_plant_details_csv_suffix: str = ".traits.csv",
    write_per_plant_summary: bool = False,
    per_plant_summary_csv_suffix: str = ".summary_traits.csv",
    monocots: bool = False,
    all_plants_csv_name: str = "all_plants_traits.csv",
) -> pd.DataFrame:
    """Get a DataFrame with summary traits from all plants in the given data folders.

    Every series returned by `find_all_series(data_folders)` is summarized with
    `get_traits_value_plant_summary`, the per plant summaries are stacked into a
    single table, and that table is always written to `all_plants_csv_name`.

    Args:
        data_folders: Folders to search for plant image series. They are passed
            as-is to `find_all_series`.
        primary_name: Name of the primary root predictions. The predictions file is
            expected to be named `"{h5_path}.{primary_name}.predictions.slp"`.
        lateral_name: Name of the lateral root predictions. The predictions file is
            expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`.
        stem_width_tolerance: The difference in the projection norm between
            the right and left side of the stem.
        n_line: The number of scan lines. Use np.nan for no interaction.
        network_fraction: The length found in the lower fraction value of the
            network.
        write_per_plant_details: A boolean value. If True, it writes per plant
            detailed CSVs with traits for every instance.
        per_plant_details_csv_suffix: If write_per_plant_details=True, the detailed
            CSV file will be saved with the name h5 path + this suffix.
        write_per_plant_summary: A boolean value. If True, it writes per plant
            summary CSVs.
        per_plant_summary_csv_suffix: If write_per_plant_summary=True, the CSV file
            with the summary statistics per plant will be saved with the name
            h5 path + this suffix.
        monocots: A boolean value indicating whether the plants are monocots (True)
            or dicots (False) (default).
        all_plants_csv_name: The name of the output CSV file containing all plants'
            summary traits.

    Returns:
        A pandas DataFrame with summary root traits for all plants in the data
        folders. Each row is a sample, and the added "path" column holds the h5
        entry the row was computed from.

    Raises:
        ValueError: If no plant series is found in `data_folders`.
    """
    all_traits = []
    for h5 in find_all_series(data_folders):
        plant_traits = get_traits_value_plant_summary(
            h5,
            monocots=monocots,
            primary_name=primary_name,
            lateral_name=lateral_name,
            stem_width_tolerance=stem_width_tolerance,
            n_line=n_line,
            network_fraction=network_fraction,
            write_csv=write_per_plant_details,
            csv_suffix=per_plant_details_csv_suffix,
            write_summary_csv=write_per_plant_summary,
            summary_csv_suffix=per_plant_summary_csv_suffix,
        )
        # Keep track of which file each summary row came from.
        plant_traits["path"] = h5
        all_traits.append(plant_traits)

    # pd.concat raises an opaque "No objects to concatenate" ValueError on an
    # empty list; fail with the same exception type but an actionable message.
    if not all_traits:
        raise ValueError(f"No plant series found in data folders: {data_folders}")

    all_traits_df = pd.concat(all_traits, ignore_index=True)

    all_traits_df.to_csv(all_plants_csv_name, index=False)
    return all_traits_df
2 changes: 1 addition & 1 deletion sleap_roots/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load(
@property
def series_name(self) -> str:
"""Name of the series derived from the HDF5 filename."""
return Path(self.h5_path).stem
return Path(self.h5_path).name.split(".")[0]

@property
def video(self) -> sio.Video:
Expand Down
42 changes: 38 additions & 4 deletions tests/test_graphpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
get_traits_value_frame,
get_traits_value_plant,
get_traits_value_plant_summary,
get_all_plants_traits,
)
import pytest
import numpy as np
import pandas as pd


@pytest.fixture
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_get_traits_value_frame(primary_pts, lateral_pts):
def test_get_traits_value_plant(canola_h5):
monocots = False

data_plant, data_plant_df = get_traits_value_plant(
data_plant, data_plant_df, plant_name = get_traits_value_plant(
canola_h5,
monocots,
primary_name="primary_multi_day",
Expand All @@ -79,10 +81,10 @@ def test_get_traits_value_plant(canola_h5):
n_line=50,
network_fraction=2 / 3,
write_csv=False,
csv_name="plant_original_traits.csv",
)
assert len(data_plant) == 72
assert data_plant_df.shape[1] == 45
assert plant_name == "919QDUH"


def test_get_traits_value_plant_summary(canola_h5):
Expand All @@ -96,10 +98,42 @@ def test_get_traits_value_plant_summary(canola_h5):
n_line=50,
network_fraction=2 / 3,
write_csv=False,
csv_name="plant_original_traits.csv",
write_summary_csv=False,
summary_csv_name="plant_summary_traits.csv",
)
assert data_plant_summary.shape[0] == 1
assert data_plant_summary.shape[1] == 1036
np.testing.assert_almost_equal(data_plant_summary.iloc[0, 5], 16.643764612148875)


def test_get_all_plants_traits_dicot(canola_folder, tmp_path):
    """Check the all-plants summary table for the canola (dicot) sample folder."""
    # Write the combined CSV into pytest's temporary directory so the test does
    # not leave `all_plants_traits.csv` behind in the current working directory.
    # The per plant CSVs are still written next to the h5 files.
    all_traits_df = get_all_plants_traits(
        data_folders=[canola_folder],
        primary_name="primary_multi_day",
        lateral_name="lateral_3_nodes",
        write_per_plant_details=True,
        write_per_plant_summary=True,
        all_plants_csv_name=str(tmp_path / "all_plants_traits.csv"),
    )
    assert all_traits_df.shape == (1, 1037)
    np.testing.assert_almost_equal(all_traits_df.iloc[0, 5], 16.643764612148875)


def tests_get_all_plants_traits_monocot(rice_folder, tmp_path):
    """Check the all-plants summary table for the rice sample folder."""
    # NOTE(review): `monocots` is left at its default (False) even though this
    # test runs on rice data — confirm this is intended before changing it, since
    # the expected value below was recorded with the default.
    # Write the combined CSV into pytest's temporary directory so the test does
    # not leave `all_plants_traits.csv` behind in the current working directory.
    # The per plant CSVs are still written next to the h5 files.
    all_traits_df = get_all_plants_traits(
        data_folders=[rice_folder],
        primary_name="longest_3do_6nodes",
        lateral_name="main_3do_6nodes",
        write_per_plant_details=True,
        write_per_plant_summary=True,
        all_plants_csv_name=str(tmp_path / "all_plants_traits.csv"),
    )
    assert all_traits_df.shape == (1, 1037)
    np.testing.assert_almost_equal(all_traits_df.iloc[0, 5], 3.716619501198254)

0 comments on commit f5dab2b

Please sign in to comment.