Skip to content

Commit

Permalink
Compute grouped summaries for multiple dicots
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Mar 30, 2024
1 parent 9bc5e18 commit bc12818
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 21 deletions.
88 changes: 76 additions & 12 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ def compute_multiple_dicots_traits(
"""
# Initialize the return structure with the series name and group
result = {
"series": series.series_name,
"group": series.group,
"series": str(series.series_name),
"group": str(series.group),
"traits": {},
"summary_stats": {},
}
Expand Down Expand Up @@ -528,7 +528,6 @@ def compute_multiple_dicots_traits_for_groups(
# Group series by their group property
series_groups = {}
for series in series_list:
print(f"Grouping series '{series.series_name}'")
group_name = str(series.group)
if group_name not in series_groups:
series_groups[group_name] = {"names": [], "series": []}

Check warning on line 533 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L529-L533

Added lines #L529 - L533 were not covered by tests
Expand All @@ -537,7 +536,6 @@ def compute_multiple_dicots_traits_for_groups(
series_groups[group_name]["series"].append(series) # Store Series objects

Check warning on line 536 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L535-L536

Added lines #L535 - L536 were not covered by tests

for group_name, group_data in series_groups.items():

Check warning on line 538 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L538

Added line #L538 was not covered by tests
print(f"Initializing group '{group_name}'")
# Initialize the return structure with the group name
group_result = {

Check warning on line 540 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L540

Added line #L540 was not covered by tests
"group": group_name,
Expand All @@ -558,7 +556,7 @@ def compute_multiple_dicots_traits_for_groups(
else:
aggregated_traits[trait].append([np.atleast_1d(values)])
group_result["traits"] = aggregated_traits
print(f"Group results: {group_result}")
print(f"Finished processing group '{group_name}'")

Check warning on line 559 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L557-L559

Added lines #L557 - L559 were not covered by tests

# Write to JSON if requested
if write_json:
Expand Down Expand Up @@ -594,7 +592,7 @@ def compute_multiple_dicots_traits_for_groups(
csv_name = f"{group_name}{csv_suffix}"
try:
summary_df = pd.DataFrame([summary_stats])
summary_df.insert(0, "group", group_name)
summary_df.insert(0, "genotype", group_name)
summary_df.to_csv(csv_name, index=False)
print(

Check warning on line 597 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L591-L597

Added lines #L591 - L597 were not covered by tests
f"Summary statistics for group {group_name} saved to {csv_name}"
Expand All @@ -607,19 +605,13 @@ def compute_multiple_dicots_traits_for_groups(
def compute_batch_traits(
self,
plants: List[Series],
write_json_per_series: bool = False,
json_suffix: str = ".all_frames_traits.json",
write_csv: bool = False,
csv_path: str = "traits.csv",
) -> pd.DataFrame:
"""Compute traits for a batch of plants.
Args:
plants: List of `Series` objects.
write_json_per_series: If `True`, write the computed traits to a JSON file
for each series.
json_suffix: The suffix to append to the JSON file name. Default is
".all_frames_traits.json".
write_csv: If `True`, write the computed traits to a CSV file.
csv_path: Path to write the CSV file to.
Expand Down Expand Up @@ -698,6 +690,78 @@ def compute_batch_multiple_dicots_traits(

return all_series_summaries_df

def compute_batch_multiple_dicots_traits_for_groups(
    self,
    all_series: List[Series],
    write_json: bool = False,
    write_csv: bool = False,
    csv_path: str = "group_summarized_traits.csv",
) -> pd.DataFrame:
    """Compute group-level summarized traits for a batch of multiple-dicot series.

    The series are handed to `compute_multiple_dicots_traits_for_groups`, and the
    summary statistics of every returned group result are collected into a single
    table.

    Args:
        all_series: List of `Series` objects. Must not be empty.
        write_json: If `True`, write each set of group traits to a JSON file. This
            flag is forwarded to `compute_multiple_dicots_traits_for_groups`.
        write_csv: If `True`, write the combined table to `csv_path`. Per-group
            CSV output is always disabled in the delegated call, so only the
            combined file is written.
        csv_path: Path to write the combined CSV file to.

    Returns:
        A pandas DataFrame with one row per group result. The first column,
        "genotype", holds the group name ("Unknown Genotype" when the result has
        no "group" entry); the remaining columns are the keys of each group's
        "summary_stats" dictionary.

    Raises:
        ValueError: If `all_series` is empty.
        RuntimeError: If computing the grouped traits fails. The original
            exception is chained as the cause.
        KeyError: If a group result has no "summary_stats" entry.
        IOError: If writing the CSV file fails. The original exception is chained
            as the cause.
    """
    # Check if the input list is empty
    if not all_series:
        raise ValueError("The input list 'all_series' is empty.")

    try:
        # Compute traits for each group of series. Per-group CSVs are suppressed
        # here because this method writes one combined CSV instead.
        grouped_results = self.compute_multiple_dicots_traits_for_groups(
            all_series, write_json=write_json, write_csv=False
        )
    except Exception as e:
        # Chain explicitly so the root cause stays visible in the traceback.
        raise RuntimeError(f"Error computing traits for groups: {e}") from e

    # Prepare the list of dictionaries for the DataFrame
    all_group_summaries = []
    for group_result in grouped_results:
        # Validate the expected key exists in the result
        if "summary_stats" not in group_result:
            raise KeyError(
                "Expected key 'summary_stats' not found in group result."
            )

        # The group name is reported as the genotype; fall back to a placeholder
        # when the result carries no "group" entry.
        genotype = group_result.get("group", "Unknown Genotype")

        # "genotype" goes first so it becomes the leading column, followed by
        # every summary statistic of the group.
        all_group_summaries.append(
            {"genotype": genotype, **group_result["summary_stats"]}
        )

    # Create a DataFrame from the list of dictionaries
    all_group_summaries_df = pd.DataFrame(all_group_summaries)

    # Write to CSV if requested
    if write_csv:
        try:
            all_group_summaries_df.to_csv(csv_path, index=False)
            print(f"Computed traits for all groups saved to {csv_path}")
        except Exception as e:
            raise IOError(f"Failed to write computed traits to CSV: {e}") from e

    return all_group_summaries_df

Check warning on line 763 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L763

Added line #L763 was not covered by tests


@attrs.define
class DicotPipeline(Pipeline):
Expand Down
26 changes: 18 additions & 8 deletions tests/test_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,23 @@ def test_root_width_canola(canola_h5):
np.array([[0, 0], [1, 1]]),
np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]),
0.02,
(np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))),
(
np.nan,
[(np.nan, np.nan)],
np.full((1, 2), np.nan),
np.full((1, 2), np.nan),
),
),
(
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]),
0.02,
(np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))),
(
np.nan,
[(np.nan, np.nan)],
np.full((1, 2), np.nan),
np.full((1, 2), np.nan),
),
),
],
)
Expand Down Expand Up @@ -416,27 +426,27 @@ def test_get_root_widths_invalid_cases():

# Minimum length
result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]))
assert np.array_equal(result, np.array([]))
assert np.isnan(result)

# Return default values with return_inds=True
result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]), return_inds=True)
# Checks if both arrays are exactly the same
assert np.array_equal(result[0], np.array([]))
assert np.isnan(result[0])
# Continue to check the other parts of the tuple
assert result[1] == [(np.nan, np.nan)]
# Check the other NumPy arrays in the tuple
assert np.array_equal(result[2], np.empty((0, 2)))
assert np.array_equal(result[3], np.empty((0, 2)))
assert np.all(np.isnan(result[2]))
assert np.all(np.isnan(result[3]))

# All NaNs in input arrays
result = get_root_widths(
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
np.array([[[np.nan, np.nan], [np.nan, np.nan]]]),
)
assert np.array_equal(result, np.array([]))
assert np.isnan(result)

# All lateral roots on the same side
result = get_root_widths(
np.array([[0, 0], [1, 1]]), np.array([[[0, 0], [1, 1]], [[0, 0], [1, 1]]])
)
assert np.array_equal(result, np.array([]))
assert np.isnan(result)
2 changes: 1 addition & 1 deletion tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_multiple_dicot_pipeline(
all_traits = pipeline.compute_batch_multiple_dicots_traits(series_all)

# Dataframe shape assertions
assert pd.DataFrame(arabidopsis_traits["summary_stats"]).shape == (1, 316)
assert pd.DataFrame([arabidopsis_traits["summary_stats"]]).shape == (1, 315)
assert all_traits.shape == (4, 316)

# Dataframe dtype assertions
Expand Down

0 comments on commit bc12818

Please sign in to comment.