diff --git a/spras/analysis/ml.py b/spras/analysis/ml.py index b82e845b..3dad8775 100644 --- a/spras/analysis/ml.py +++ b/spras/analysis/ml.py @@ -88,9 +88,11 @@ def summarize_networks(file_paths: Iterable[Union[str, PathLike]]) -> pd.DataFra return concated_df + def validate_df(dataframe: pd.DataFrame): """ Raises an error if the dataframe is empty or contains one pathway (one row) + @param dataframe: dataframe of pathways to validate """ if dataframe.empty: raise ValueError("ML post-processing cannot proceed because the summarize network dataframe is empty.\nWe " @@ -100,6 +102,7 @@ def validate_df(dataframe: pd.DataFrame): f"The ml post-processing requires more than one pathway, but currently " f"there are only {min(dataframe.shape)} pathways.") + def create_palette(column_names): """ Generates a dictionary mapping each column name (algorithm name) diff --git a/spras/config.py b/spras/config.py index e3bdde77..14f1a926 100644 --- a/spras/config.py +++ b/spras/config.py @@ -80,11 +80,13 @@ def __init__(self, raw_config): # Only includes algorithms that are set to be run with 'include: true'. self.algorithm_params = None # Deprecated. Previously a dict mapping algorithm names to a Boolean tracking whether they used directed graphs. - self.algorithm_directed = None + self.algorithm_directed = None # A dict with the analysis settings self.analysis_params = None # A dict with the ML settings self.ml_params = None + # A Boolean specifying whether to run ML analysis for individual algorithms + self.analysis_include_ml_aggregate_algo = None # A dict with the PCA settings self.pca_params = None # A dict with the hierarchical clustering settings @@ -254,7 +256,7 @@ def process_config(self, raw_config): raise ValueError("Evaluation analysis cannot run as gold standard data not provided. 
" "Please set evaluation include to false or provide gold standard data.") - if 'aggregate_per_algorithm' in self.ml_params and self.analysis_include_ml == True: + if 'aggregate_per_algorithm' in self.ml_params and self.analysis_include_ml: self.analysis_include_ml_aggregate_algo = raw_config["analysis"]["ml"]["aggregate_per_algorithm"] else: self.analysis_include_ml_aggregate_algo = False diff --git a/test/ml/test_ml.py b/test/ml/test_ml.py index 020bcce8..2b5720ae 100644 --- a/test/ml/test_ml.py +++ b/test/ml/test_ml.py @@ -97,7 +97,6 @@ def test_pca_robustness(self): assert coord.equals(expected) - def test_hac_horizontal(self): dataframe = ml.summarize_networks([INPUT_DIR + 'test-data-s1/s1.txt', INPUT_DIR + 'test-data-s2/s2.txt', INPUT_DIR + 'test-data-s3/s3.txt']) ml.hac_horizontal(dataframe, OUT_DIR + 'hac-horizontal.png', OUT_DIR + 'hac-clusters-horizontal.txt') @@ -138,6 +137,5 @@ def test_ensemble_network_empty(self): en = pd.read_table(OUT_DIR + 'ensemble-network-empty.tsv') expected = pd.read_table(EXPECT_DIR + 'expected-ensemble-network-empty.tsv') - expected = expected.round(5) assert en.equals(expected)