diff --git a/fusilli/data.py b/fusilli/data.py index ad833c6..2a2a505 100644 --- a/fusilli/data.py +++ b/fusilli/data.py @@ -225,16 +225,18 @@ def __init__(self, sources, img_downsample_dims=None): # read in the csv files and raise errors if they don't have the right columns # or if the index column is not named "ID" tab1_df = pd.read_csv(self.tabular1_source) - tab2_df = pd.read_csv(self.tabular2_source) - if "ID" not in tab1_df.columns: raise ValueError("The CSV must have an index column named 'ID'.") if "prediction_label" not in tab1_df.columns: raise ValueError("The CSV must have a label column named 'prediction_label'.") - if "ID" not in tab2_df.columns: - raise ValueError("The CSV must have an index column named 'ID'.") - if "prediction_label" not in tab2_df.columns: - raise ValueError("The CSV must have a label column named 'prediction_label'.") + + # if tabular2_source exists, check it has the right columns + if self.tabular2_source != "": + tab2_df = pd.read_csv(self.tabular2_source) + if "ID" not in tab2_df.columns: + raise ValueError("The CSV must have an index column named 'ID'.") + if "prediction_label" not in tab2_df.columns: + raise ValueError("The CSV must have a label column named 'prediction_label'.") def load_tabular1(self): """