Skip to content

Commit

Permalink
fix tests for pipelines to take in slp paths
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed May 13, 2024
1 parent f686a6e commit 892d12d
Showing 1 changed file with 69 additions and 35 deletions.
104 changes: 69 additions & 35 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,27 @@
)


def test_dicot_pipeline(canola_h5, soy_h5):
def test_dicot_pipeline(
canola_h5,
soy_h5,
canola_primary_slp,
canola_lateral_slp,
soy_primary_slp,
soy_lateral_slp,
):
# Load the data
canola = Series.load(canola_h5, primary_name="primary", lateral_name="lateral")
soy = Series.load(soy_h5, primary_name="primary", lateral_name="lateral")
canola = Series.load(
series_name="canola",
h5_path=canola_h5,
primary_path=canola_primary_slp,
lateral_path=canola_lateral_slp,
)
soy = Series.load(
series_name="soy",
h5_path=soy_h5,
primary_path=soy_primary_slp,
lateral_path=soy_lateral_slp,
)

pipeline = DicotPipeline()
canola_traits = pipeline.compute_plant_traits(canola)
Expand All @@ -30,26 +47,36 @@ def test_dicot_pipeline(canola_h5, soy_h5):
assert all_traits.shape == (2, 1036)


def test_OlderMonocot_pipeline(rice_main_10do_h5):
rice = Series.load(rice_main_10do_h5, crown_name="crown")
def test_OlderMonocot_pipeline(rice_main_10do_h5, rice_main_10do_slp):
rice = Series.load(
series_name="rice_10do",
h5_path=rice_main_10do_h5,
crown_path=rice_main_10do_slp,
)

pipeline = OlderMonocotPipeline()
rice_10dag_traits = pipeline.compute_plant_traits(rice)

assert rice_10dag_traits.shape == (72, 102)


def test_younger_monocot_pipeline(rice_h5, rice_folder):
rice = Series.load(rice_h5, primary_name="primary", crown_name="crown")
rice_series_all = find_all_series(rice_folder)
series_all = [
Series.load(series, primary_name="primary", crown_name="crown")
for series in rice_series_all
]

def test_younger_monocot_pipeline(rice_pipeline_output_folder):
# Find slp paths in folder
slp_paths = find_all_slp_paths(rice_pipeline_output_folder)
assert len(slp_paths) == 4
# Load series from slps
rice_series_all = load_series_from_slps(
slp_paths=slp_paths, h5s=False, csv_path=None
)
assert len(rice_series_all) == 2
# Get first series
rice_series = rice_series_all[0]
# Initialize pipeline for younger monocot
pipeline = YoungerMonocotPipeline()
rice_traits = pipeline.compute_plant_traits(rice)
all_traits = pipeline.compute_batch_traits(series_all)
# Get traits for the first series using the pipeline
rice_traits = pipeline.compute_plant_traits(rice_series)
# Get all traits for all series using the pipeline
all_traits = pipeline.compute_batch_traits(rice_series_all)

# Dataframe shape assertions
assert rice_traits.shape == (72, 104)
Expand Down Expand Up @@ -96,14 +123,22 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder):
).all(), "angle_column in all_traits contains values out of range [0, 180]"


def test_older_monocot_pipeline(rice_main_10do_h5, rice_10do_folder):
rice = Series.load(rice_main_10do_h5, crown_name="crown")
rice_series_all = find_all_series(rice_10do_folder)
series_all = [Series.load(series, crown_name="crown") for series in rice_series_all]
def test_older_monocot_pipeline(rice_10do_pipeline_output_folder):
# Find slp paths in folder
slp_paths = find_all_slp_paths(rice_10do_pipeline_output_folder)
assert len(slp_paths) == 1
# Load series from slps
rice_series_all = load_series_from_slps(
slp_paths=slp_paths, h5s=False, csv_path=None
)
assert len(rice_series_all) == 1
# Get first series
rice_series = rice_series_all[0]
assert rice_series.series_name == "scan0K9E8BI"

pipeline = OlderMonocotPipeline()
rice_traits = pipeline.compute_plant_traits(rice)
all_traits = pipeline.compute_batch_traits(series_all)
all_traits = pipeline.compute_batch_traits(rice_series_all)
rice_traits = pipeline.compute_plant_traits(rice_series)

# Dataframe shape assertions
assert rice_traits.shape == (72, 102)
Expand Down Expand Up @@ -148,27 +183,26 @@ def test_multiple_dicot_pipeline(
multiple_arabidopsis_11do_h5,
multiple_arabidopsis_11do_folder,
multiple_arabidopsis_11do_csv,
multiple_arabidopsis_11do_primary_slp,
multiple_arabidopsis_11do_lateral_slp,
):
arabidopsis = Series.load(
multiple_arabidopsis_11do_h5,
primary_name="primary",
lateral_name="lateral",
series_name="997_1",
h5_path=multiple_arabidopsis_11do_h5,
primary_path=multiple_arabidopsis_11do_primary_slp,
lateral_path=multiple_arabidopsis_11do_lateral_slp,
csv_path=multiple_arabidopsis_11do_csv,
)
arabidopsis_slp_paths = find_all_slp_paths(multiple_arabidopsis_11do_folder)
arabidopsis_series_all = load_series_from_slps(
slp_paths=arabidopsis_slp_paths,
h5s=True,
csv_path=multiple_arabidopsis_11do_csv,
)
arabidopsis_series_all = find_all_series(multiple_arabidopsis_11do_folder)
series_all = [
Series.load(
series,
primary_name="primary",
lateral_name="lateral",
csv_path=multiple_arabidopsis_11do_csv,
)
for series in arabidopsis_series_all
]

pipeline = MultipleDicotPipeline()
arabidopsis_traits = pipeline.compute_multiple_dicots_traits(arabidopsis)
all_traits = pipeline.compute_batch_multiple_dicots_traits(series_all)
all_traits = pipeline.compute_batch_multiple_dicots_traits(arabidopsis_series_all)

# Dataframe shape assertions
assert pd.DataFrame([arabidopsis_traits["summary_stats"]]).shape == (1, 315)
Expand Down

0 comments on commit 892d12d

Please sign in to comment.