diff --git a/cybench/config.py b/cybench/config.py
index af27ff04..4e934ddf 100644
--- a/cybench/config.py
+++ b/cybench/config.py
@@ -11,6 +11,13 @@
 PATH_DATA_DIR = os.path.join(CONFIG_DIR, "data")
 os.makedirs(PATH_DATA_DIR, exist_ok=True)
 
+# Path to folder where aligned data is stored.
+# NOTE: Data is saved after aligning time series to the crop season.
+# Similarly, label data is aligned to have the same locations and years
+# as the input data.
+PATH_ALIGNED_DATA_DIR = os.path.join(CONFIG_DIR, "aligned_data")
+os.makedirs(PATH_ALIGNED_DATA_DIR, exist_ok=True)
+
 # Path to folder where output is stored
 PATH_OUTPUT_DIR = os.path.join(CONFIG_DIR, "output")
 os.makedirs(PATH_OUTPUT_DIR, exist_ok=True)
diff --git a/cybench/datasets/alignment.py b/cybench/datasets/alignment.py
index 1f3c1a08..f5d297da 100644
--- a/cybench/datasets/alignment.py
+++ b/cybench/datasets/alignment.py
@@ -125,7 +125,7 @@ def trim_to_lead_time(df, crop_cal_df, lead_time, spinup_days=90):
     return df
 
 
-def align_data(df_y: pd.DataFrame, dfs_x: tuple) -> tuple:
+def align_data(df_y: pd.DataFrame, dfs_x: dict) -> tuple:
     # Data Alignment
     # - Filter the label data based on presence within all feature data sets
     # - Filter feature data based on label data
@@ -133,7 +133,7 @@
 
     # Filter label data
     index_y_selection = set(df_y.index.values)
-    for df_x in dfs_x:
+    for df_x in dfs_x.values():
         if len(df_x.index.names) == 1:
             index_y_selection = {
                 (loc_id, year)
diff --git a/cybench/datasets/configured.py b/cybench/datasets/configured.py
index 60d035c0..c10a1e66 100644
--- a/cybench/datasets/configured.py
+++ b/cybench/datasets/configured.py
@@ -1,10 +1,10 @@
 import os
 import pandas as pd
 
-from datetime import date, timedelta
 
 from cybench.config import (
     PATH_DATA_DIR,
+    PATH_ALIGNED_DATA_DIR,
     DATASETS,
     KEY_LOC,
     KEY_YEAR,
@@ -52,6 +52,7 @@ def load_dfs(
     df_y = df_y[[KEY_LOC, KEY_YEAR, KEY_TARGET]]
     df_y = df_y.dropna(axis=0)
     df_y = df_y[df_y[KEY_TARGET] > 0.0]
+    df_y.set_index([KEY_LOC, KEY_YEAR], inplace=True)
 
     # soil
     df_x_soil = pd.read_csv(
@@ -59,6 +60,7 @@
         header=0,
     )
     df_x_soil = df_x_soil[[KEY_LOC] + SOIL_PROPERTIES]
+    df_x_soil.set_index([KEY_LOC], inplace=True)
 
     # crop calendar
     df_crop_cal = pd.read_csv(
@@ -80,7 +82,7 @@
     df_x_meteo = _preprocess_time_series_data(
         df_x_meteo, ts_index_cols, METEO_INDICATORS, df_crop_cal, lead_time
     )
-    df_x_meteo = df_x_meteo.set_index(ts_index_cols)
+    df_x_meteo.set_index(ts_index_cols, inplace=True)
 
     # fpar
     df_x_fpar = pd.read_csv(
@@ -90,7 +92,7 @@
     df_x_fpar = _preprocess_time_series_data(
         df_x_fpar, ts_index_cols, [RS_FPAR], df_crop_cal, lead_time
     )
-    df_x_fpar = df_x_fpar.set_index(ts_index_cols)
+    df_x_fpar.set_index(ts_index_cols, inplace=True)
 
     # ndvi
     df_x_ndvi = pd.read_csv(
@@ -100,7 +102,7 @@
     df_x_ndvi = _preprocess_time_series_data(
         df_x_ndvi, ts_index_cols, [RS_NDVI], df_crop_cal, lead_time
     )
-    df_x_ndvi = df_x_ndvi.set_index(ts_index_cols)
+    df_x_ndvi.set_index(ts_index_cols, inplace=True)
 
     # soil moisture
     df_x_soil_moisture = pd.read_csv(
@@ -116,54 +118,129 @@
         df_crop_cal,
         lead_time,
     )
-    df_x_soil_moisture = df_x_soil_moisture.set_index(ts_index_cols)
+    df_x_soil_moisture.set_index(ts_index_cols, inplace=True)
 
-    df_y = df_y.set_index([KEY_LOC, KEY_YEAR])
-    df_x_soil = df_x_soil.set_index([KEY_LOC])
-    dfs_x = (df_x_soil, df_x_meteo, df_x_fpar, df_x_ndvi, df_x_soil_moisture)
+    dfs_x = {
+        "soil": df_x_soil,
+        "meteo": df_x_meteo,
+        RS_FPAR: df_x_fpar,
+        RS_NDVI: df_x_ndvi,
+        "soil_moisture": df_x_soil_moisture,
+    }
 
     df_y, dfs_x = align_data(df_y, dfs_x)
+    return df_y, dfs_x
+
+
+def load_aligned_dfs(crop: str, country_code: str) -> tuple:
+    path_data_cn = os.path.join(PATH_ALIGNED_DATA_DIR, crop, country_code)
+    # targets
+    df_y = pd.read_csv(
+        os.path.join(path_data_cn, "_".join([KEY_TARGET, crop, country_code]) + ".csv"),
+        header=0,
+        index_col=[KEY_LOC, KEY_YEAR],
+    )
+
+    # soil
+    df_x_soil = pd.read_csv(
+        os.path.join(path_data_cn, "_".join(["soil", crop, country_code]) + ".csv"),
+        header=0,
+        index_col=[KEY_LOC],
+    )
+
+    # Time series data
+    ts_index_cols = [KEY_LOC, KEY_YEAR, "date"]
+    # meteo
+    df_x_meteo = pd.read_csv(
+        os.path.join(path_data_cn, "_".join(["meteo", crop, country_code]) + ".csv"),
+        header=0,
+        index_col=ts_index_cols,
+    )
+
+    # fpar
+    df_x_fpar = pd.read_csv(
+        os.path.join(path_data_cn, "_".join([RS_FPAR, crop, country_code]) + ".csv"),
+        header=0,
+        index_col=ts_index_cols,
+    )
+
+    # ndvi
+    df_x_ndvi = pd.read_csv(
+        os.path.join(path_data_cn, "_".join([RS_NDVI, crop, country_code]) + ".csv"),
+        header=0,
+        index_col=ts_index_cols,
+    )
+
+    # soil moisture
+    df_x_soil_moisture = pd.read_csv(
+        os.path.join(
+            path_data_cn, "_".join(["soil_moisture", crop, country_code]) + ".csv"
+        ),
+        header=0,
+        index_col=ts_index_cols,
+    )
+
+    dfs_x = {
+        "soil": df_x_soil,
+        "meteo": df_x_meteo,
+        RS_FPAR: df_x_fpar,
+        RS_NDVI: df_x_ndvi,
+        "soil_moisture": df_x_soil_moisture,
+    }
 
     return df_y, dfs_x
 
 
 def load_dfs_crop(crop: str, countries: list = None) -> tuple:
     assert crop in DATASETS
 
     if countries is None:
         countries = DATASETS[crop]
 
     df_y = pd.DataFrame()
-    dfs_x = tuple()
+    dfs_x = {}
     for cn in countries:
-        if not os.path.exists(os.path.join(PATH_DATA_DIR, crop, cn)):
+        # load aligned data if it exists
+        if os.path.exists(os.path.join(PATH_ALIGNED_DATA_DIR, crop, cn)):
+            df_y_cn, dfs_x_cn = load_aligned_dfs(crop, cn)
+        elif os.path.exists(os.path.join(PATH_DATA_DIR, crop, cn)):
+            df_y_cn, dfs_x_cn = load_dfs(crop, cn)
+            # save aligned data
+            cn_data_dir = os.path.join(PATH_ALIGNED_DATA_DIR, crop, cn)
+            os.makedirs(cn_data_dir, exist_ok=True)
+            df_y_cn.to_csv(
+                os.path.join(cn_data_dir, "_".join([KEY_TARGET, crop, cn]) + ".csv"),
+            )
+            for x, df_x in dfs_x_cn.items():
+                df_x.to_csv(
+                    os.path.join(cn_data_dir, "_".join([x, crop, cn]) + ".csv"),
+                )
+        else:
             continue
 
-        df_y_cn, dfs_x_cn = load_dfs(crop, cn)
         df_y = pd.concat([df_y, df_y_cn], axis=0)
         if len(dfs_x) == 0:
             dfs_x = dfs_x_cn
         else:
-            dfs_x = tuple(
-                pd.concat([df_x, df_x_cn], axis=0)
-                for df_x, df_x_cn in zip(dfs_x, dfs_x_cn)
-            )
+            for x, df_x_cn in dfs_x_cn.items():
+                dfs_x[x] = pd.concat([dfs_x[x], df_x_cn], axis=0)
 
-    new_dfs_x = tuple()
     # keep the same number of time steps for time series data
     # NOTE: At this point, each df_x contains data for all selected countries.
-    for df_x in dfs_x:
-        # If index is [KEY_LOC, KEY_YEAR, "date"]
-        if "date" in df_x.index.names:
-            index_names = df_x.index.names
-            column_names = list(df_x.columns)
-            df_x.reset_index(inplace=True)
-            min_time_steps = df_x.groupby([KEY_LOC, KEY_YEAR])["date"].count().min()
-            df_x = df_x.sort_values(by=[KEY_LOC, KEY_YEAR, "date"])
-            df_x = df_x.groupby([KEY_LOC, KEY_YEAR]).tail(min_time_steps).reset_index()
-            df_x.set_index(index_names, inplace=True)
-            df_x = df_x[column_names]
-
-        new_dfs_x += (df_x,)
-
-    return df_y, new_dfs_x
+    if len(countries) > 1:
+        for x in dfs_x:
+            df_x = dfs_x[x]
+            # If index is [KEY_LOC, KEY_YEAR, "date"]
+            if "date" in df_x.index.names:
+                index_names = df_x.index.names
+                column_names = list(df_x.columns)
+                df_x.reset_index(inplace=True)
+                min_time_steps = df_x.groupby([KEY_LOC, KEY_YEAR])["date"].count().min()
+                df_x = df_x.sort_values(by=[KEY_LOC, KEY_YEAR, "date"])
+                df_x = df_x.groupby([KEY_LOC, KEY_YEAR]).tail(min_time_steps).reset_index()
+                df_x.set_index(index_names, inplace=True)
+                df_x = df_x[column_names]
+
+            dfs_x[x] = df_x
+
+    return df_y, dfs_x
diff --git a/cybench/datasets/dataset.py b/cybench/datasets/dataset.py
index 23efb62a..a7b21fdf 100644
--- a/cybench/datasets/dataset.py
+++ b/cybench/datasets/dataset.py
@@ -91,6 +91,7 @@ def load(dataset_name: str) -> "Dataset":
         crop = crop_countries[0]
         assert crop in DATASETS, Exception(f'Unrecognized crop name "{crop}"')
 
+        # only the crop is specified; load all countries for it
         if len(crop_countries) < 2:
             country_codes = DATASETS[crop]
         else:
@@ -105,7 +106,7 @@
         return Dataset(
             crop,
             df_y,
-            list(dfs_x),
+            list(dfs_x.values()),
         )
 
     @property
diff --git a/cybench/datasets/transforms.py b/cybench/datasets/transforms.py
index 689c5646..ba4168dd 100644
--- a/cybench/datasets/transforms.py
+++ b/cybench/datasets/transforms.py
@@ -7,15 +7,14 @@
 
 
 def transform_ts_inputs_to_dekadal(batch, min_date, max_date):
-    min_dekad = dekad_from_date(min_date)
-    max_dekad = dekad_from_date(max_date)
+    min_dekad = dekad_from_date(str(min_date))
+    max_dekad = dekad_from_date(str(max_date))
     dekads = list(range(0, max_dekad - min_dekad + 1))
     for key in TIME_SERIES_PREDICTORS:
         value = batch[key]
         # Transform dates to dekads
-        date_strs = [str(date) for date in batch[KEY_DATES][key]]
        value_dekads = torch.tensor(
-            [dekad_from_date(date) for date in date_strs], device=value.device
+            [dekad_from_date(str(date)) for date in batch[KEY_DATES][key]], device=value.device
         )
         value_dekads -= 1
 
diff --git a/tests/datasets/test_configured.py b/tests/datasets/test_configured.py
index 196deeda..3410c6d2 100644
--- a/tests/datasets/test_configured.py
+++ b/tests/datasets/test_configured.py
@@ -6,11 +6,11 @@ def test_load_dfs_crop():
 
     # Sort indices for fast lookup
     df_y.sort_index(inplace=True)
-    for df_x in dfs_x:
-        df_x.sort_index(inplace=True)
+    for x in dfs_x:
+        dfs_x[x] = dfs_x[x].sort_index()
 
     for i, row in df_y.iterrows():
-        for df_x in dfs_x:
+        for df_x in dfs_x.values():
             if len(df_x.index.names) == 1:
                 assert i[0] in df_x.index
             else:
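
Note for reviewers: below is a minimal, illustrative sketch of the caching behavior this patch introduces in load_dfs_crop(). It is not part of the patch. It assumes raw data for one crop/country pair exists under PATH_DATA_DIR; "maize" and "NL" are stand-ins chosen only for illustration.

# Illustrative sketch only (assumes maize/NL raw data is available).
import os
import shutil

from cybench.config import PATH_ALIGNED_DATA_DIR
from cybench.datasets.configured import load_dfs_crop

# Start from a cold cache for this crop/country pair.
cache_dir = os.path.join(PATH_ALIGNED_DATA_DIR, "maize", "NL")
shutil.rmtree(cache_dir, ignore_errors=True)

# First call: reads the raw CSVs, aligns labels and features, and writes the
# aligned frames to PATH_ALIGNED_DATA_DIR/maize/NL as per-feature CSV files.
df_y, dfs_x = load_dfs_crop("maize", countries=["NL"])
assert os.path.isdir(cache_dir)

# dfs_x is now a dict keyed by feature group rather than a positional tuple.
print(sorted(dfs_x.keys()))

# Second call: the cache directory exists, so load_aligned_dfs() reads the
# pre-aligned CSVs and the alignment step is skipped entirely.
df_y_cached, dfs_x_cached = load_dfs_crop("maize", countries=["NL"])
assert df_y_cached.shape == df_y.shape

One thing worth double-checking in review: load_aligned_dfs() trusts whatever is on disk, so a stale cache (e.g. after the raw data under PATH_DATA_DIR is updated) has to be invalidated manually by deleting the corresponding folder under PATH_ALIGNED_DATA_DIR.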