Commit
#257 save aligned data
krsnapaudel committed Aug 30, 2024
1 parent 3e9389f commit 6abbdd6
Showing 6 changed files with 126 additions and 42 deletions.
7 changes: 7 additions & 0 deletions cybench/config.py
@@ -11,6 +11,13 @@
PATH_DATA_DIR = os.path.join(CONFIG_DIR, "data")
os.makedirs(PATH_DATA_DIR, exist_ok=True)

+# Path to folder where aligned data is stored
+# NOTE: Data is saved after aligning time series to the crop season.
+# Similarly, label data is aligned to have the same locations and years
+# as the input data.
+PATH_ALIGNED_DATA_DIR = os.path.join(CONFIG_DIR, "aligned_data")
+os.makedirs(PATH_ALIGNED_DATA_DIR, exist_ok=True)
+
# Path to folder where output is stored
PATH_OUTPUT_DIR = os.path.join(CONFIG_DIR, "output")
os.makedirs(PATH_OUTPUT_DIR, exist_ok=True)
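
For orientation, this is the cache layout the new directory implies, based on how load_dfs_crop (below) writes one CSV per country and per data source. The crop and country names here are illustrative, and "fpar"/"ndvi" assume those are the values of RS_FPAR and RS_NDVI:

```python
# Illustrative layout for crop="maize", country code "NL" (hypothetical names):
#
# aligned_data/
#   maize/
#     NL/
#       yield_maize_NL.csv          # aligned labels
#       soil_maize_NL.csv           # static soil properties, indexed by location
#       meteo_maize_NL.csv          # time series, indexed by (location, year, date)
#       fpar_maize_NL.csv
#       ndvi_maize_NL.csv
#       soil_moisture_maize_NL.csv
```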
4 changes: 2 additions & 2 deletions cybench/datasets/alignment.py
@@ -125,15 +125,15 @@ def trim_to_lead_time(df, crop_cal_df, lead_time, spinup_days=90):
return df


-def align_data(df_y: pd.DataFrame, dfs_x: tuple) -> tuple:
+def align_data(df_y: pd.DataFrame, dfs_x: dict) -> tuple:
# Data Alignment
# - Filter the label data based on presence within all feature data sets
# - Filter feature data based on label data

# Filter label data

index_y_selection = set(df_y.index.values)
-for df_x in dfs_x:
+for df_x in dfs_x.values():
if len(df_x.index.names) == 1:
index_y_selection = {
(loc_id, year)
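The signature change turns the tuple of feature frames into a dict keyed by data source; the filtering itself is unchanged. As a rough sketch of that logic (not the library's exact code), labels are kept only where every feature frame has matching data:

```python
import pandas as pd

def align_sketch(df_y: pd.DataFrame, dfs_x: dict) -> pd.DataFrame:
    # Keep only the (location, year) label rows present in every feature frame.
    keep = set(df_y.index)
    for df_x in dfs_x.values():
        if len(df_x.index.names) == 1:
            # Static data indexed by location only: match on location.
            locs = set(df_x.index)
            keep = {(loc, year) for loc, year in keep if loc in locs}
        else:
            # Time series indexed by (location, year, date): match on both keys.
            pairs = {(loc, year) for loc, year, *_ in df_x.index}
            keep &= pairs
    return df_y.loc[sorted(keep)]
```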
141 changes: 109 additions & 32 deletions cybench/datasets/configured.py
@@ -1,10 +1,10 @@
import os
import pandas as pd
-from datetime import date, timedelta


from cybench.config import (
PATH_DATA_DIR,
+PATH_ALIGNED_DATA_DIR,
DATASETS,
KEY_LOC,
KEY_YEAR,
@@ -52,13 +52,15 @@ def load_dfs(
df_y = df_y[[KEY_LOC, KEY_YEAR, KEY_TARGET]]
df_y = df_y.dropna(axis=0)
df_y = df_y[df_y[KEY_TARGET] > 0.0]
+df_y.set_index([KEY_LOC, KEY_YEAR], inplace=True)

# soil
df_x_soil = pd.read_csv(
os.path.join(path_data_cn, "_".join(["soil", crop, country_code]) + ".csv"),
header=0,
)
df_x_soil = df_x_soil[[KEY_LOC] + SOIL_PROPERTIES]
+df_x_soil.set_index([KEY_LOC], inplace=True)

# crop calendar
df_crop_cal = pd.read_csv(
@@ -80,7 +82,7 @@ def load_dfs(
df_x_meteo = _preprocess_time_series_data(
df_x_meteo, ts_index_cols, METEO_INDICATORS, df_crop_cal, lead_time
)
-df_x_meteo = df_x_meteo.set_index(ts_index_cols)
+df_x_meteo.set_index(ts_index_cols, inplace=True)

# fpar
df_x_fpar = pd.read_csv(
@@ -90,7 +92,7 @@ def load_dfs(
df_x_fpar = _preprocess_time_series_data(
df_x_fpar, ts_index_cols, [RS_FPAR], df_crop_cal, lead_time
)
-df_x_fpar = df_x_fpar.set_index(ts_index_cols)
+df_x_fpar.set_index(ts_index_cols, inplace=True)

# ndvi
df_x_ndvi = pd.read_csv(
@@ -100,7 +102,7 @@ def load_dfs(
df_x_ndvi = _preprocess_time_series_data(
df_x_ndvi, ts_index_cols, [RS_NDVI], df_crop_cal, lead_time
)
-df_x_ndvi = df_x_ndvi.set_index(ts_index_cols)
+df_x_ndvi.set_index(ts_index_cols, inplace=True)

# soil moisture
df_x_soil_moisture = pd.read_csv(
@@ -116,54 +118,129 @@ def load_dfs(
df_crop_cal,
lead_time,
)
-df_x_soil_moisture = df_x_soil_moisture.set_index(ts_index_cols)
+df_x_soil_moisture.set_index(ts_index_cols, inplace=True)

-df_y = df_y.set_index([KEY_LOC, KEY_YEAR])
-df_x_soil = df_x_soil.set_index([KEY_LOC])
-dfs_x = (df_x_soil, df_x_meteo, df_x_fpar, df_x_ndvi, df_x_soil_moisture)
+dfs_x = {
+"soil": df_x_soil,
+"meteo": df_x_meteo,
+RS_FPAR: df_x_fpar,
+RS_NDVI: df_x_ndvi,
+"soil_moisture": df_x_soil_moisture,
+}

df_y, dfs_x = align_data(df_y, dfs_x)
return df_y, dfs_x


+def load_aligned_dfs(crop: str, country_code: str) -> tuple:
+path_data_cn = os.path.join(PATH_ALIGNED_DATA_DIR, crop, country_code)
+# targets
+df_y = pd.read_csv(
+os.path.join(path_data_cn, "_".join(["yield", crop, country_code]) + ".csv"),
+header=0,
+index_col=[KEY_LOC, KEY_YEAR],
+)
+
+# soil
+df_x_soil = pd.read_csv(
+os.path.join(path_data_cn, "_".join(["soil", crop, country_code]) + ".csv"),
+header=0,
+index_col=[KEY_LOC],
+)
+
+# Time series data
+ts_index_cols = [KEY_LOC, KEY_YEAR, "date"]
+# meteo
+df_x_meteo = pd.read_csv(
+os.path.join(path_data_cn, "_".join(["meteo", crop, country_code]) + ".csv"),
+header=0,
+index_col=ts_index_cols,
+)
+
+# fpar
+df_x_fpar = pd.read_csv(
+os.path.join(path_data_cn, "_".join([RS_FPAR, crop, country_code]) + ".csv"),
+header=0,
+index_col=ts_index_cols,
+)
+
+# ndvi
+df_x_ndvi = pd.read_csv(
+os.path.join(path_data_cn, "_".join([RS_NDVI, crop, country_code]) + ".csv"),
+header=0,
+index_col=ts_index_cols,
+)
+
+# soil moisture
+df_x_soil_moisture = pd.read_csv(
+os.path.join(
+path_data_cn, "_".join(["soil_moisture", crop, country_code]) + ".csv"
+),
+header=0,
+index_col=ts_index_cols,
+)
+
+dfs_x = {
+"soil": df_x_soil,
+"meteo": df_x_meteo,
+RS_FPAR: df_x_fpar,
+RS_NDVI: df_x_ndvi,
+"soil_moisture": df_x_soil_moisture,
+}
+
+return df_y, dfs_x
+

-def load_dfs_crop(crop: str, countries: list = None) -> tuple:
+def load_dfs_crop(crop: str, countries: list = None) -> dict:
assert crop in DATASETS

if countries is None:
countries = DATASETS[crop]

df_y = pd.DataFrame()
-dfs_x = tuple()
+dfs_x = {}
for cn in countries:
-if not os.path.exists(os.path.join(PATH_DATA_DIR, crop, cn)):
+# load aligned data if exists
+if os.path.exists(os.path.join(PATH_ALIGNED_DATA_DIR, crop, cn)):
+df_y_cn, dfs_x_cn = load_aligned_dfs(crop, cn)
+elif os.path.exists(os.path.join(PATH_DATA_DIR, crop, cn)):
+df_y_cn, dfs_x_cn = load_dfs(crop, cn)
+# save aligned data
+cn_data_dir = os.path.join(PATH_ALIGNED_DATA_DIR, crop, cn)
+os.makedirs(cn_data_dir, exist_ok=True)
+df_y_cn.to_csv(
+os.path.join(cn_data_dir, "_".join([KEY_TARGET, crop, cn]) + ".csv"),
+)
+for x, df_x in dfs_x_cn.items():
+df_x.to_csv(
+os.path.join(cn_data_dir, "_".join([x, crop, cn]) + ".csv"),
+)
+else:
continue

-df_y_cn, dfs_x_cn = load_dfs(crop, cn)
df_y = pd.concat([df_y, df_y_cn], axis=0)
if len(dfs_x) == 0:
dfs_x = dfs_x_cn
else:
-dfs_x = tuple(
-pd.concat([df_x, df_x_cn], axis=0)
-for df_x, df_x_cn in zip(dfs_x, dfs_x_cn)
-)
+for x, df_x_cn in dfs_x_cn.items():
+dfs_x[x] = pd.concat([dfs_x[x], df_x_cn], axis=0)

-new_dfs_x = tuple()
# keep the same number of time steps for time series data
# NOTE: At this point, each df_x contains data for all selected countries.
-for df_x in dfs_x:
-# If index is [KEY_LOC, KEY_YEAR, "date"]
-if "date" in df_x.index.names:
-index_names = df_x.index.names
-column_names = list(df_x.columns)
-df_x.reset_index(inplace=True)
-min_time_steps = df_x.groupby([KEY_LOC, KEY_YEAR])["date"].count().min()
-df_x = df_x.sort_values(by=[KEY_LOC, KEY_YEAR, "date"])
-df_x = df_x.groupby([KEY_LOC, KEY_YEAR]).tail(min_time_steps).reset_index()
-df_x.set_index(index_names, inplace=True)
-df_x = df_x[column_names]
-
-new_dfs_x += (df_x,)
-
-return df_y, new_dfs_x
+if len(countries) > 1:
+for x in dfs_x:
+df_x = dfs_x[x]
+# If index is [KEY_LOC, KEY_YEAR, "date"]
+if "date" in df_x.index.names:
+index_names = df_x.index.names
+column_names = list(df_x.columns)
+df_x.reset_index(inplace=True)
+min_time_steps = df_x.groupby([KEY_LOC, KEY_YEAR])["date"].count().min()
+df_x = df_x.sort_values(by=[KEY_LOC, KEY_YEAR, "date"])
+df_x = df_x.groupby([KEY_LOC, KEY_YEAR]).tail(min_time_steps).reset_index()
+df_x.set_index(index_names, inplace=True)
+df_x = df_x[column_names]
+
+dfs_x[x] = df_x
+
+return df_y, dfs_x
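
The per-country caching above is the new behavior; the tail-trimming at the end is the old logic, now applied only when multiple countries are combined. A toy illustration of that trimming (hypothetical values; "loc_id"/"year" stand in for KEY_LOC/KEY_YEAR):

```python
import pandas as pd

df = pd.DataFrame({
    "loc_id": ["A", "A", "A", "B", "B"],
    "year":   [2020, 2020, 2020, 2020, 2020],
    "date":   ["20200101", "20200111", "20200121", "20200111", "20200121"],
    "tmax":   [1.0, 2.0, 3.0, 4.0, 5.0],
})
# Shortest series length across (loc, year) groups: B has only 2 dates.
min_time_steps = df.groupby(["loc_id", "year"])["date"].count().min()
# Keep only the latest min_time_steps dates per group so all series align.
trimmed = (
    df.sort_values(by=["loc_id", "year", "date"])
      .groupby(["loc_id", "year"])
      .tail(min_time_steps)
)
print(trimmed)  # location A drops its earliest date, "20200101"
```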
3 changes: 2 additions & 1 deletion cybench/datasets/dataset.py
@@ -91,6 +91,7 @@ def load(dataset_name: str) -> "Dataset":
crop = crop_countries[0]
assert crop in DATASETS, Exception(f'Unrecognized crop name "{crop}"')

+# only crop is specified
if len(crop_countries) < 2:
country_codes = DATASETS[crop]
else:
@@ -105,7 +106,7 @@ def load(
return Dataset(
crop,
df_y,
-list(dfs_x),
+list(dfs_x.values()),
)

@property
7 changes: 3 additions & 4 deletions cybench/datasets/transforms.py
@@ -7,15 +7,14 @@


def transform_ts_inputs_to_dekadal(batch, min_date, max_date):
-min_dekad = dekad_from_date(min_date)
-max_dekad = dekad_from_date(max_date)
+min_dekad = dekad_from_date(str(min_date))
+max_dekad = dekad_from_date(str(max_date))
dekads = list(range(0, max_dekad - min_dekad + 1))
for key in TIME_SERIES_PREDICTORS:
value = batch[key]
# Transform dates to dekads
-date_strs = [str(date) for date in batch[KEY_DATES][key]]
value_dekads = torch.tensor(
-[dekad_from_date(date) for date in date_strs], device=value.device
+[dekad_from_date(str(date)) for date in batch[KEY_DATES][key]], device=value.device
)
value_dekads -= 1

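This change wraps dates in str(...) because dekad_from_date evidently expects date strings rather than date objects. For context, a hedged sketch of dekad indexing under the common convention of three dekads per month (days 1-10, 11-20, 21-end); cybench's actual dekad_from_date may differ, e.g. in its indexing base or accepted formats:

```python
from datetime import date

def dekad_from_date_sketch(date_str: str) -> int:
    # Map an ISO date string to a dekad of the year in 1..36.
    d = date.fromisoformat(date_str)
    dekad_in_month = min((d.day - 1) // 10, 2)  # 0, 1, or 2 within the month
    return (d.month - 1) * 3 + dekad_in_month + 1

assert dekad_from_date_sketch("2024-01-05") == 1
assert dekad_from_date_sketch("2024-01-25") == 3
assert dekad_from_date_sketch("2024-12-31") == 36
```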
6 changes: 3 additions & 3 deletions tests/datasets/test_configured.py
@@ -6,11 +6,11 @@ def test_load_dfs_crop():

# Sort indices for fast lookup
df_y.sort_index(inplace=True)
-for df_x in dfs_x:
-df_x.sort_index(inplace=True)
+for x in dfs_x:
+dfs_x[x] = dfs_x[x].sort_index()

for i, row in df_y.iterrows():
-for df_x in dfs_x:
+for df_x in dfs_x.values():
if len(df_x.index.names) == 1:
assert i[0] in df_x.index
else: