Commit db18dc5: cleaning
teyaberg committed May 31, 2024
1 parent c8f26ea commit db18dc5
Showing 1 changed file with 101 additions and 89 deletions.
190 changes: 101 additions & 89 deletions scripts/xgboost_sweep.py
@@ -27,7 +27,7 @@ def __init__(self, cfg: DictConfig, split: str = "train"):
        self.static_data_path = self.data_path / "static" / split
        self._data_shards = [shard.stem for shard in list(self.static_data_path.glob("*.parquet"))]
        if cfg.iterator.keep_static_data_in_memory:
-            self._static_shards = self._get_static_shards()
+            self._static_shards = self._collect_static_shards_in_memory()

        self.codes_set, self.aggs_set, self.min_frequency_set, self.window_set = self._get_inclusion_sets()
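For orientation: the iterator reads several keys off a Hydra/OmegaConf config. A minimal sketch of a config carrying every key this diff references (key names come from the code; the values are invented):

```python
from omegaconf import OmegaConf

# feature_freqs.json (read when min_code_inclusion_frequency is set) is assumed
# to map frequency keys to integer counts, e.g. {"lab/HR/value": 127, ...}.
cfg = OmegaConf.create(
    {
        "iterator": {"keep_static_data_in_memory": True},
        "codes": None,  # None disables code filtering; otherwise a list of code strings
        "aggs": ["value/sum", "code/count"],  # hypothetical aggregation names
        "min_code_inclusion_frequency": 10,
        "window_sizes": ["30d", "365d"],  # hypothetical window names
    }
)
```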

@@ -46,32 +46,26 @@ def _get_inclusion_sets(self) -> tuple[set, set, set]:
        codes_set = None
        aggs_set = None
        min_frequency_set = None
+        window_set = None

        if self.cfg.codes is not None:
            codes_set = set(self.cfg.codes)

        if self.cfg.aggs is not None:
            aggs_set = set(self.cfg.aggs)

        if self.cfg.min_code_inclusion_frequency is not None:
            with open(self.data_path / "feature_freqs.json") as f:
                feature_freqs = json.load(f)
            min_frequency_set = {
                key for key, value in feature_freqs.items() if value >= self.cfg.min_code_inclusion_frequency
            }
-        window_set = set(self.cfg.window_sizes)
+        if self.cfg.window_sizes is not None:
+            window_set = set(self.cfg.window_sizes)

        return codes_set, aggs_set, min_frequency_set, window_set

-    def _load_static_shard_by_index(self, idx: int) -> sp.csc_matrix:
-        """Load a static shard into memory.
-        Args:
-        - idx (int): Index of the shard to load.
-        Returns:
-        - sp.csc_matrix: Sparse matrix with the static shard.
-        """
-        return pd.read_parquet(self.static_data_path / f"{self._data_shards[int(idx)]}.parquet")

-    def _get_static_shards(self) -> dict:
+    def _collect_static_shards_in_memory(self) -> dict:
        """Load static shards into memory.
        Returns:
@@ -82,59 +76,16 @@ def _get_static_shards(self) -> dict:
            static_shards[iter] = self._load_static_shard_by_index(iter)
        return static_shards

-    def _sparsify_shard(self, df: pd.DataFrame) -> tuple[sp.csc_matrix, np.ndarray]:
-        """Make X and y as scipy sparse arrays for XGBoost.
-        Args:
-        - df (pandas.DataFrame): Data frame to sparsify.
-        Returns:
-        - tuple[scipy.sparse.csr_matrix, numpy.ndarray]: Tuple of feature data and labels.
-        """
-        labels = df.loc[:, [col for col in df.columns if col.endswith("/task")]]
-        data = df.drop(columns=labels.columns)
-        for col in data.columns:
-            if not isinstance(data[col].dtype, pd.SparseDtype):
-                data[col] = pd.arrays.SparseArray(data[col])
-        sparse_matrix = data.sparse.to_coo()
-        return csr_matrix(sparse_matrix), labels.values

-    def _validate_shard_file_inclusion(self, file: Path) -> bool:
-        parts = file.relative_to(self.dynamic_data_path).parts
-        if not parts:
-            return False
-
-        windows_part = parts[0]
-        aggs_part = "/".join(parts[1:-1])
-
-        return (self.window_set is None or windows_part in self.window_set) and (
-            self.aggs_set is None or aggs_part in self.aggs_set
-        )

-    def _assert_correct_sorting(self, shard: pd.DataFrame):
-        """Assert that the shard is sorted correctly."""
-        if "timestamp" in shard.columns:
-            sort_columns = ["patient_id", "timestamp"]
-        else:
-            sort_columns = ["patient_id"]
-        assert shard[sort_columns].equals(shard[sort_columns].sort_values(by=sort_columns)), (
-            "Shard is not sorted on correctly. "
-            "Please ensure that the data is sorted on patient_id and timestamp, if applicable."
-        )

-    def _get_sparse_dynamic_shard_from_file(self, path: Path) -> pd.DataFrame:
-        """Load a sparse shard into memory. This returns a shard as a pandas dataframe, asserted that it is
-        sorted on patient id and timestamp, if included.
+    def _load_static_shard_by_index(self, idx: int) -> sp.csc_matrix:
+        """Load a static shard into memory.
        Args:
-        - path (Path): Path to the sparse shard.
+        - idx (int): Index of the shard to load.
        Returns:
-        - pd.DataFrame: Data frame with the sparse shard.
+        - sp.csc_matrix: Sparse matrix with the static shard.
        """
-        shard = pd.read_pickle(path)
-        self._assert_correct_sorting(shard)
-        return shard.drop(columns=["patient_id", "timestamp"])
+        return pd.read_parquet(self.static_data_path / f"{self._data_shards[int(idx)]}.parquet")

    def _get_static_shard_by_index(self, idx: int) -> pd.DataFrame:
        """Get the static shard from memory or disk.
@@ -150,7 +101,7 @@ def _get_static_shard_by_index(self, idx: int) -> pd.DataFrame:
        else:
            return self._load_static_shard_by_index(self._data_shards[idx])

-    def _get_task(self, idx: int) -> pd.DataFrame:
+    def _get_task_by_index(self, idx: int) -> pd.DataFrame:
        """Get the task data for a specific shard.
        Args:
@@ -165,29 +116,32 @@ def _get_task(self, idx: int) -> pd.DataFrame:
        shard["label"] = np.random.randint(0, 2, shard.shape[0])
        return shard[["patient_id", "timestamp", "label"]]

-    def _filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
-        """Filter the dynamic data frame based on the inclusion sets.
+    def _load_dynamic_shard_from_file(self, path: Path) -> pd.DataFrame:
+        """Load a sparse shard into memory. This returns a shard as a pandas dataframe, asserting that it
+        is sorted on patient_id and timestamp, if present.
        Args:
-        - df (pd.DataFrame): Data frame to filter.
+        - path (Path): Path to the sparse shard.
        Returns:
-        - pd.DataFrame: Filtered data frame.
+        - pd.DataFrame: Data frame with the sparse shard.
        """
-        code_parts = ["/".join(col.split("/")[1:-2]) for col in df.columns]
-        frequency_parts = ["/".join(col.split("/")[1:-1]) for col in df.columns]
+        shard = pd.read_pickle(path)
+        self._assert_correct_sorting(shard)
+        return shard.drop(columns=["patient_id", "timestamp"])

-        filtered_columns = [
-            col
-            for col, code_part, freq_part in zip(df.columns, code_parts, frequency_parts)
-            if (self.codes_set is None or code_part in self.codes_set)
-            and (self.min_frequency_set is None or freq_part in self.min_frequency_set)
-        ]
-        filtered_columns.extend([col for col in df.columns if col.endswith("/task")])
+    def _get_dynamic_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
+        """Load a specific shard of dynamic data from disk and return it as a sparse matrix after filtering
+        column inclusion."""
+        files = list(self.dynamic_data_path.glob(f"*/*/*/{self._data_shards[idx]}.pkl"))

-        return df[filtered_columns]
+        files = [file for file in files if self._filter_shard_files_on_window_and_aggs(file)]

-    def _load_dynamic_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
+        dynamic_dfs = [self._load_dynamic_shard_from_file(file) for file in files]
+        dynamic_df = pd.concat(dynamic_dfs, axis=1)
+        return self._filter_shard_on_codes_and_freqs(dynamic_df)

    def _get_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
        """Load a specific shard of data from disk and concatenate with static data.
        Args:
@@ -198,28 +152,84 @@ def _load_dynamic_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
        - y (numpy.ndarray): Labels.
        """

-        files = list(self.dynamic_data_path.glob(f"*/*/*/{self._data_shards[idx]}.pkl"))
-
-        files = [file for file in files if self._validate_shard_file_inclusion(file)]
-
-        dynamic_dfs = [self._get_sparse_dynamic_shard_from_file(file) for file in files]
-        dynamic_df = pd.concat(dynamic_dfs, axis=1)
-        dynamic_df = self._filter_df(dynamic_df)
+        dynamic_df = self._get_dynamic_shard_by_index(idx)

        # TODO: add in some type checking etc for safety
        static_df = self._get_static_shard_by_index(idx)

-        task_df = self._get_task(idx)
+        task_df = self._get_task_by_index(idx)
        task_df = task_df.rename(
            columns={col: f"{col}/task" for col in task_df.columns if col not in ["patient_id", "timestamp"]}
        )
        df = pd.merge(task_df, static_df, on=["patient_id"], how="left")
        self._assert_correct_sorting(df)
-        df = self._filter_df(df)
+        df = self._filter_shard_on_codes_and_freqs(df)
        df = pd.concat([df, dynamic_df], axis=1)

        return self._sparsify_shard(df)
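To make the assembly above concrete, a toy example of the rename-then-merge step (the data and the `static/height/first` column name are invented):

```python
import pandas as pd

task_df = pd.DataFrame(
    {"patient_id": [1, 2], "timestamp": ["2020-01-01", "2020-01-02"], "label": [0, 1]}
)
# Suffix every non-key column with "/task" so labels are recoverable later.
task_df = task_df.rename(
    columns={col: f"{col}/task" for col in task_df.columns if col not in ["patient_id", "timestamp"]}
)
static_df = pd.DataFrame({"patient_id": [1, 2], "static/height/first": [1.7, 1.6]})
df = pd.merge(task_df, static_df, on=["patient_id"], how="left")
print(df.columns.tolist())
# ['patient_id', 'timestamp', 'label/task', 'static/height/first']
```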

+    def _sparsify_shard(self, df: pd.DataFrame) -> tuple[sp.csr_matrix, np.ndarray]:
+        """Make X and y as scipy sparse arrays for XGBoost.
+        Args:
+        - df (pandas.DataFrame): Data frame to sparsify.
+        Returns:
+        - tuple[scipy.sparse.csr_matrix, numpy.ndarray]: Tuple of feature data and labels.
+        """
+        labels = df.loc[:, [col for col in df.columns if col.endswith("/task")]]
+        data = df.drop(columns=labels.columns)
+        for col in data.columns:
+            if not isinstance(data[col].dtype, pd.SparseDtype):
+                data[col] = pd.arrays.SparseArray(data[col])
+        sparse_matrix = data.sparse.to_coo()
+        return csr_matrix(sparse_matrix), labels.values
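A minimal standalone sketch of this sparsification step, with invented column names; every non-label column is cast to a pandas `SparseArray` so the `.sparse.to_coo()` accessor can build a single COO matrix:

```python
import pandas as pd
from scipy.sparse import csr_matrix

df = pd.DataFrame(
    {
        "30d/codeA/sum": [0.0, 1.5, 0.0],
        "30d/codeB/count": [1.0, 0.0, 2.0],
        "label/task": [0, 1, 0],  # label columns end in "/task"
    }
)
labels = df.loc[:, [c for c in df.columns if c.endswith("/task")]]
data = df.drop(columns=labels.columns)
for col in data.columns:
    if not isinstance(data[col].dtype, pd.SparseDtype):
        data[col] = pd.arrays.SparseArray(data[col])
X = csr_matrix(data.sparse.to_coo())
y = labels.values
print(X.shape, y.ravel())  # (3, 2) [0 1 0]
```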

+    def _filter_shard_files_on_window_and_aggs(self, file: Path) -> bool:
+        parts = file.relative_to(self.dynamic_data_path).parts
+        if not parts:
+            return False
+
+        windows_part = parts[0]
+        aggs_part = "/".join(parts[1:-1])
+
+        return (self.window_set is None or windows_part in self.window_set) and (
+            self.aggs_set is None or aggs_part in self.aggs_set
+        )
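This assumes a layout like `dynamic_data_path/<window>/<agg group>/<agg name>/<shard>.pkl`, matching the `*/*/*/<shard>.pkl` glob above; the layout and names in this sketch are hypothetical:

```python
from pathlib import Path

dynamic_data_path = Path("data/dynamic")
file = dynamic_data_path / "30d" / "value" / "sum" / "shard_0.pkl"
parts = file.relative_to(dynamic_data_path).parts  # ('30d', 'value', 'sum', 'shard_0.pkl')
windows_part = parts[0]            # '30d'
aggs_part = "/".join(parts[1:-1])  # 'value/sum'
window_set, aggs_set = {"30d", "365d"}, {"value/sum", "code/count"}
print(windows_part in window_set and aggs_part in aggs_set)  # True
```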

+    def _filter_shard_on_codes_and_freqs(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Filter the dynamic data frame based on the inclusion sets.
+        Args:
+        - df (pd.DataFrame): Data frame to filter.
+        Returns:
+        - pd.DataFrame: Filtered data frame.
+        """
+        code_parts = ["/".join(col.split("/")[1:-2]) for col in df.columns]
+        frequency_parts = ["/".join(col.split("/")[1:-1]) for col in df.columns]
+
+        filtered_columns = [
+            col
+            for col, code_part, freq_part in zip(df.columns, code_parts, frequency_parts)
+            if (self.codes_set is None or code_part in self.codes_set)
+            and (self.min_frequency_set is None or freq_part in self.min_frequency_set)
+        ]
+        filtered_columns.extend([col for col in df.columns if col.endswith("/task")])
+
+        return df[filtered_columns]
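The slicing implies a column scheme of roughly `<window>/<code...>/<agg>`, where the code key may itself contain slashes and the frequency key drops only the final aggregation segment; the scheme is inferred, and the names below are invented:

```python
col = "30d/lab/HR/value/sum"
code_part = "/".join(col.split("/")[1:-2])  # 'lab/HR'
freq_part = "/".join(col.split("/")[1:-1])  # 'lab/HR/value'
codes_set = {"lab/HR"}
min_frequency_set = {"lab/HR/value"}
print(code_part in codes_set and freq_part in min_frequency_set)  # True: column kept
```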

+    def _assert_correct_sorting(self, shard: pd.DataFrame):
+        """Assert that the shard is sorted correctly."""
+        if "timestamp" in shard.columns:
+            sort_columns = ["patient_id", "timestamp"]
+        else:
+            sort_columns = ["patient_id"]
+        assert shard[sort_columns].equals(shard[sort_columns].sort_values(by=sort_columns)), (
+            "Shard is not sorted correctly. "
+            "Please ensure that the data is sorted on patient_id and timestamp, if applicable."
+        )

    def next(self, input_data: Callable):
        """Advance the iterator by 1 step and pass the data to XGBoost. This function is called by XGBoost
        during the construction of ``DMatrix``
Expand All @@ -236,7 +246,7 @@ def next(self, input_data: Callable):

# input_data is a function passed in by XGBoost who has the exact same signature of
# ``DMatrix``
X, y = self._load_dynamic_shard_by_index(self._it) # self._data_shards[self._it])
X, y = self._get_shard_by_index(self._it) # self._data_shards[self._it])
input_data(data=X, label=y)
self._it += 1
# Return 1 to let XGBoost know we haven't seen all the files yet.
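This `next`/`reset` pair follows XGBoost's `DataIter` protocol for iterated (external-memory) `DMatrix` construction. A self-contained sketch of how such an iterator is consumed, assuming a reasonably recent xgboost; the class name, shard contents, and cache path are invented:

```python
import numpy as np
import xgboost as xgb
from scipy.sparse import csr_matrix


class ToyIter(xgb.DataIter):
    """Yields two tiny in-memory shards, mimicking the iterator above."""

    def __init__(self):
        self._shards = [
            (csr_matrix(np.eye(4)), np.array([0, 1, 0, 1])),
            (csr_matrix(np.eye(4)), np.array([1, 0, 1, 0])),
        ]
        self._it = 0
        super().__init__(cache_prefix="cache")  # on-disk cache for external memory

    def next(self, input_data):
        if self._it == len(self._shards):
            return 0  # tell XGBoost we are done
        X, y = self._shards[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1  # more shards remain

    def reset(self):
        self._it = 0


dtrain = xgb.DMatrix(ToyIter())
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)
```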
@@ -255,7 +265,7 @@ def collect_in_memory(self) -> tuple[sp.csr_matrix, np.ndarray]:
        X = []
        y = []
        for i in range(len(self._data_shards)):
-            X_, y_ = self._load_dynamic_shard_by_index(i)
+            X_, y_ = self._get_shard_by_index(i)
            X.append(X_)
            y.append(y_)
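The tail of `collect_in_memory` is collapsed in this view; combining per-shard pieces like these is typically a row-wise sparse stack plus a label concatenation (a sketch of the common pattern, not necessarily the hidden lines):

```python
import numpy as np
import scipy.sparse as sp

X = [sp.csr_matrix(np.eye(2)), sp.csr_matrix(np.eye(2))]
y = [np.array([0, 1]), np.array([1, 0])]
X_all = sp.vstack(X)       # (4, 2) CSR matrix
y_all = np.concatenate(y)  # array([0, 1, 1, 0])
```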

@@ -288,6 +298,8 @@ def __init__(self, cfg: DictConfig):
    def train(self):
        """Train the model."""
        self._build()
+        # TODO: add in eval, early stopping, etc.
+        # TODO: check for NaN and inf in labels!
        self.model = xgb.train(
            OmegaConf.to_container(self.cfg.model), self.dtrain
        )  # do we want eval and things?
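On the first TODO: a hedged sketch of what evaluation plus early stopping could look like with the standard `xgb.train` arguments (toy random data here; the real code would pass `OmegaConf.to_container(self.cfg.model)` as the params):

```python
import numpy as np
import xgboost as xgb

# Toy stand-ins for self.dtrain and a held-out evaluation DMatrix.
rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((64, 4)), label=rng.integers(0, 2, 64))
dtune = xgb.DMatrix(rng.random((32, 4)), label=rng.integers(0, 2, 32))

model = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, "train"), (dtune, "tuning")],
    early_stopping_rounds=10,  # stop once the tuning logloss stops improving
)
```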
