Xgboost is able to load all concatenated windows and aggregations. Fixed bugs related to event ids and column ids being incorrect.
Oufattole committed Jun 2, 2024
1 parent 3a412a0 commit a4f1843
Showing 8 changed files with 142 additions and 113 deletions.
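
In short, the updated Iterator resolves one sparse file per (window size, aggregation) pair, loads each file as a sparse matrix, concatenates the blocks column-wise, and keeps only the labeled rows. A rough sketch of that flow with toy matrices standing in for the real shards (shapes and values are invented, not from the repo):

```python
import numpy as np
import scipy.sparse as sp

# Toy stand-ins for two (window_size, aggregation) blocks of one shard:
# rows are events, columns are that aggregation's features.
count_30d = sp.csc_matrix(np.array([[1, 0], [0, 2], [3, 0]]))
sum_30d = sp.csc_matrix(np.array([[0.5], [0.0], [1.2]]))

# Concatenate all window/aggregation blocks column-wise, as the Iterator does.
combined = sp.hstack([count_30d, sum_30d], format="csr")

# Keep only the rows whose event ids have labels for this shard.
valid_event_ids = np.array([0, 2])
combined = combined[valid_event_ids]
print(combined.shape)  # (2, 3)
```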
1 change: 1 addition & 0 deletions configs/xgboost_sweep.yaml
@@ -22,6 +22,7 @@ do_overwrite: False
do_update: True
seed: 1
tqdm: True
test: False

model:
booster: gbtree
29 changes: 27 additions & 2 deletions hf_cohort/aces_task_extraction.py
@@ -1,6 +1,7 @@
"""
Set up the Conda environment as described here: https://github.com/justin13601/ACES
"""
import json
from pathlib import Path

import hydra
@@ -9,6 +10,29 @@
from tqdm import tqdm


def get_events_df(shard_df: pl.DataFrame, feature_columns) -> pl.DataFrame:
"""Extracts Events DataFrame with one row per observation (timestamps can be duplicated)"""
# Filter out feature_columns that were not present in the training set
raw_feature_columns = ["/".join(c.split("/")[:-1]) for c in feature_columns]
shard_df = shard_df.filter(pl.col("code").is_in(raw_feature_columns))
# Drop rows with missing timestamp or code to get events
ts_shard_df = shard_df.drop_nulls(subset=["timestamp", "code"])
return ts_shard_df


def get_unique_time_events_df(events_df: pl.DataFrame):
"""Updates Events DataFrame to have unique timestamps and sorted by patient_id and timestamp."""
assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0
# Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline
events_df = (
events_df.drop_nulls("timestamp")
.select(pl.col(["patient_id", "timestamp"]))
.unique(maintain_order=True)
)
assert events_df.sort(by=["patient_id", "timestamp"]).collect().equals(events_df.collect())
return events_df


@hydra.main(version_base=None, config_path="../configs", config_name="tabularize")
def main(cfg):
# create task configuration object
@@ -37,10 +61,11 @@ def main(cfg):
.rename({"trigger": "timestamp", "subject_id": "patient_id"})
.sort(by=["patient_id", "timestamp"])
)
feature_columns = json.load(open(Path(cfg.tabularized_data_dir) / "feature_columns.json"))
data_df = pl.scan_parquet(in_fp)
data_df = data_df.unique(subset=["patient_id", "timestamp"]).sort(by=["patient_id", "timestamp"])
data_df = data_df.with_row_index("event_id")
data_df = get_unique_time_events_df(get_events_df(data_df, feature_columns))
data_df = data_df.drop(["code", "numerical_value"])
data_df = data_df.with_row_index("event_id")
output_df = label_df.lazy().join_asof(other=data_df, by="patient_id", on="timestamp")

# store it
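
As an aside on the event id fix above: the row index is now added only after filtering and deduplication, so event_id runs 0..n-1 and lines up with the rows of the tabularized matrices. A toy polars example (data invented for illustration) of the difference:

```python
import polars as pl
from datetime import datetime

# Invented MEDS-style rows: two observations share a timestamp, one has no code.
raw = pl.DataFrame(
    {
        "patient_id": [1, 1, 1, 1],
        "timestamp": [datetime(2020, 1, 1), datetime(2020, 1, 1),
                      datetime(2020, 1, 2), datetime(2020, 1, 2)],
        "code": ["lab//A", "lab//B", None, "lab//C"],
    }
)

# Old order: index first, then filter/dedup -> event_ids [0, 3], which no longer
# match the 0..n-1 row positions of the sparse matrices built downstream.
old = (
    raw.with_row_index("event_id")
    .drop_nulls("code")
    .unique(subset=["patient_id", "timestamp"], keep="first", maintain_order=True)
)

# New order: filter/dedup first, then index -> event_ids [0, 1].
new = (
    raw.drop_nulls("code")
    .unique(subset=["patient_id", "timestamp"], keep="first", maintain_order=True)
    .with_row_index("event_id")
)
print(old["event_id"].to_list(), new["event_id"].to_list())  # [0, 3] [0, 1]
```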
4 changes: 1 addition & 3 deletions scripts/tabularize_static.py
@@ -121,9 +121,7 @@ def tabularize_static_data(
np.random.shuffle(tabularization_tasks)

for shard_fp, agg in iter_wrapper(tabularization_tasks):
static_fp = f_name_resolver.get_flat_static_rep(
shard_fp.parent.stem, shard_fp.stem, agg.split("/")[-1]
)
static_fp = f_name_resolver.get_flat_static_rep(shard_fp.parent.stem, shard_fp.stem, agg)
if static_fp.exists() and not cfg.do_overwrite:
raise FileExistsError(f"do_overwrite is {cfg.do_overwrite} and {static_fp} exists!")

122 changes: 35 additions & 87 deletions scripts/xgboost_sweep.py
@@ -14,6 +14,9 @@
from omegaconf import DictConfig, OmegaConf
from sklearn.metrics import mean_absolute_error

from MEDS_tabular_automl.file_name import FileNameResolver
from MEDS_tabular_automl.utils import get_feature_indices, load_matrix


class Iterator(xgb.DataIter):
def __init__(self, cfg: DictConfig, split: str = "train"):
@@ -24,14 +27,18 @@ def __init__(self, cfg: DictConfig, split: str = "train"):
- split (str): The data split to use ("train", "tuning", or "held_out").
"""
self.cfg = cfg
self.data_path = Path(cfg.tabularized_data_dir)
self.dynamic_data_path = self.data_path / "sparse" / split
self.task_data_path = self.data_path / "task" / split
self.file_name_resolver = FileNameResolver(cfg)
self.split = split
# self.data_path = Path(cfg.tabularized_data_dir)
# self.dynamic_data_path = self.data_path / "sparse" / split
# self.task_data_path = self.data_path / "task" / split
self._data_shards = sorted(
[shard.stem for shard in list(self.task_data_path.glob("*.parquet"))]
[shard.stem for shard in self.file_name_resolver.list_label_files(split)]
) # [2, 4, 5] #
self.valid_event_ids, self.labels = self.load_labels()
self.window_set, self.aggs_set, self.codes_set, self.num_features = self._get_inclusion_sets()
self.codes_set, self.num_features = self._get_code_set()
feature_columns = json.load(open(self.file_name_resolver.get_feature_columns_fp()))
self.agg_to_feature_ids = {agg: get_feature_indices(agg, feature_columns) for agg in cfg.aggs}

self._it = 0

@@ -48,7 +55,9 @@ def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]:
in the sparse matrix
dictionary from shard number to list of labels for these valid event ids
"""
label_fps = {shard: self.task_data_path / f"{shard}.parquet" for shard in self._data_shards}
label_fps = {
shard: self.file_name_resolver.get_label(self.split, shard) for shard in self._data_shards
}
cached_labels, cached_event_ids = dict(), dict()
for shard, label_fp in label_fps.items():
label_df = pl.scan_parquet(label_fp)
@@ -58,14 +67,14 @@ def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]:

def _get_code_set(self) -> set:
"""Get the set of codes to include in the data based on the configuration."""
with open(self.data_path / "feature_columns.json") as f:
with open(self.file_name_resolver.get_feature_columns_fp()) as f:
feature_columns = json.load(f)
feature_dict = {col: i for i, col in enumerate(feature_columns)}
if self.cfg.codes is not None:
codes_set = {feature_dict[code] for code in set(self.cfg.codes) if code in feature_dict}

if self.cfg.min_code_inclusion_frequency is not None:
with open(self.data_path / "feature_freqs.json") as f:
with open(self.file_name_resolver.get_feature_freqs_fp()) as f:
feature_freqs = json.load(f)
min_frequency_set = {
key for key, value in feature_freqs.items() if value >= self.cfg.min_code_inclusion_frequency
@@ -83,53 +92,6 @@ def _get_code_set(self) -> set:
# TODO: make sure we aren't filtering out static columns!!!
return list(codes_set), len(feature_columns)

def _get_inclusion_sets(self) -> tuple[set, set, np.array]:
"""Get the inclusion sets for aggregations, window sizes, and a mask for minimum code frequency.
Returns:
- Tuple[Optional[Set[str]], Optional[Set[str]], np.ndarray]: Tuple containing:
- Set of aggregations.
- Set of window sizes.
- Boolean array mask indicating which feature columns meet the inclusion criteria.
Examples:
>>> import tempfile
>>> from types import SimpleNamespace
>>> cfg = SimpleNamespace(
... aggs=["code/count", "value/sum"],
... window_sizes=None,
... codes=["code1", "code2", "value1"],
... min_code_inclusion_frequency=2
... )
>>> with tempfile.TemporaryDirectory() as tempdir:
... data_path = Path(tempdir)
... cfg.tabularized_data_dir = str(data_path)
... feature_columns = ["code1/code", "code2/code", "value1/value"]
... feature_freqs = {"code1": 3, "code2": 1, "value1": 5}
... with open(data_path / "feature_columns.json", "w") as f:
... json.dump(feature_columns, f)
... with open(data_path / "feature_freqs.json", "w") as f:
... json.dump(feature_freqs, f)
... iterator = Iterator(cfg)
... aggs_set, window_set, mask = iterator._get_inclusion_sets()
... assert aggs_set == {"code/count", "value/sum"}
... assert window_set == None
... assert np.array_equal(mask, [True, False, True])
"""

window_set = None
aggs_set = None

if self.cfg.aggs is not None:
aggs_set = set(self.cfg.aggs)

if self.cfg.window_sizes is not None:
window_set = set(self.cfg.window_sizes)

codes_set, num_features = self._get_code_set()

return sorted(window_set), sorted(aggs_set), sorted(codes_set), num_features

def _load_dynamic_shard_from_file(self, path: Path, idx: int) -> sp.csc_matrix:
"""Load a sparse shard into memory.
Expand Down Expand Up @@ -170,16 +132,13 @@ def _load_dynamic_shard_from_file(self, path: Path, idx: int) -> sp.csc_matrix:
... assert np.array_equal(loaded_shard.indptr, expected_csr.indptr)
"""
# column_shard is of form event_idx, feature_idx, value
column_shard = np.load(path).T # TODO: Fix this!!!

shard = sp.csc_matrix(
(column_shard[:, 0], (column_shard[:, 1], column_shard[:, 2])),
shape=(
max(self.valid_event_ids[self._data_shards[idx]], column_shard[:, 1]) + 1,
self.num_features,
),
)
return self._filter_shard_on_codes_and_freqs(shard)
matrix = load_matrix(path)
if path.stem in ["first", "present"]:
agg = f"static/{path.stem}"
else:
agg = f"{path.parent.stem}/{path.stem}"

return self._filter_shard_on_codes_and_freqs(agg, sp.csc_matrix(matrix))

def _get_dynamic_shard_by_index(self, idx: int) -> sp.csr_matrix:
"""Load a specific shard of dynamic data from disk and return it as a sparse matrix after filtering
@@ -191,11 +150,14 @@ def _get_dynamic_shard_by_index(self, idx: int) -> sp.csr_matrix:
Returns:
- sp.csr_matrix: Filtered sparse matrix.
"""
# TODO Nassim Fix this guy
# get all window_size x aggregation files using the file resolver
files = self.file_name_resolver.get_model_files(
self.cfg.window_sizes, self.cfg.aggs, self.split, self._data_shards[idx]
)
assert all([file.exists() for file in files])
shard_name = self._data_shards[idx]
shard_pattern = f"*/*/*/{shard_name}.npy"
files = self.dynamic_data_path.glob(shard_pattern)
valid_files = sorted(file for file in files if self._filter_shard_files_on_window_and_aggs(file))
dynamic_csrs = [self._load_dynamic_shard_from_file(file, idx) for file in valid_files]
dynamic_csrs = [self._load_dynamic_shard_from_file(file, idx) for file in files]
combined_csr = sp.hstack(dynamic_csrs, format="csr") # TODO: check this
# Filter Rows
valid_indices = self.valid_event_ids[shard_name]
@@ -219,23 +181,7 @@ def _get_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
logger.debug(f"Task data loading took {datetime.now() - time}")
return dynamic_df, label_df

def _filter_shard_files_on_window_and_aggs(self, file: Path) -> bool:
parts = file.relative_to(self.dynamic_data_path).parts
if len(parts) < 2:
return False

windows_part = parts[0]
aggs_part = "/".join(parts[1:-1])

if self.window_set is not None and windows_part not in self.window_set:
return False

if self.aggs_set is not None and aggs_part not in self.aggs_set:
return False

return True

def _filter_shard_on_codes_and_freqs(self, df: sp.csc_matrix) -> sp.csc_matrix:
def _filter_shard_on_codes_and_freqs(self, agg: str, df: sp.csc_matrix) -> sp.csc_matrix:
"""Filter the dynamic data frame based on the inclusion sets. Given the codes_mask, filter the data
frame to only include columns that are True in the mask.
@@ -247,7 +193,9 @@ def _filter_shard_on_codes_and_freqs(self, df: sp.csc_matrix) -> sp.csc_matrix:
"""
if self.codes_set is None:
return df
return df[:, self.codes_set] # [:, list({index for index in self.codes_set if index < df.shape[1]})]
feature_ids = self.agg_to_feature_ids[agg]
code_mask = [True if idx in self.codes_set else False for idx in feature_ids]
return df[:, code_mask] # [:, list({index for index in self.codes_set if index < df.shape[1]})]

def next(self, input_data: Callable):
"""Advance the iterator by 1 step and pass the data to XGBoost. This function is called by XGBoost
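
For reference, a toy version of the per-aggregation column filtering that `_filter_shard_on_codes_and_freqs` now performs (feature names and indices are invented; `get_feature_indices` is assumed to return the global feature ids owned by a given aggregation):

```python
import numpy as np
import scipy.sparse as sp

# Global feature ids owned by this aggregation's columns (illustrative).
feature_ids = [0, 1]   # e.g. the two "code/count" columns
codes_set = {0, 2}     # global ids that survived the code/frequency filters

# Boolean mask over this shard's columns, in the same order as feature_ids.
code_mask = np.array([idx in codes_set for idx in feature_ids])  # [True, False]

shard = sp.csc_matrix(np.array([[1, 4], [2, 5], [3, 6]]))
filtered = shard[:, code_mask]   # keeps only the column for feature id 0
print(filtered.toarray())        # [[1], [2], [3]]
```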
30 changes: 29 additions & 1 deletion src/MEDS_tabular_automl/file_name.py
@@ -22,6 +22,9 @@ def get_ts_dir(self):
def get_sparse_dir(self):
return self.tabularize_dir / "sparse"

def get_label_dir(self):
return self.tabularize_dir / "task"

def get_feature_columns_fp(self):
return self.tabularize_dir / "feature_columns.json"

@@ -37,7 +40,8 @@ def get_meds_shard(self, split: str, shard_num: int):

def get_flat_static_rep(self, split: str, shard_num: int, agg: str):
# Given a shard number, returns the static representation path
return self.get_static_dir() / split / f"{shard_num}" / f"{agg}.npz"
agg_name = agg.split("/")[-1]
return self.get_static_dir() / split / f"{shard_num}" / f"{agg_name}.npz"

def get_flat_ts_rep(self, split: str, shard_num: int, window_size: int, agg: str):
# Given a shard number, returns the time series representation path
@@ -47,6 +51,10 @@ def get_flat_sparse_rep(self, split: str, shard_num: int, window_size: int, agg:
# Given a shard number, returns the sparse representation path
return self.get_sparse_dir() / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz"

def get_label(self, split: str, shard_num: int):
# Given a shard number, returns the label path
return self.get_label_dir() / split / f"{shard_num}.parquet"

def list_meds_files(self, split=None):
# List all MEDS files
if split:
@@ -70,3 +78,23 @@ def list_sparse_files(self, split=None):
if split:
return sorted(list(self.get_sparse_dir().glob(f"{split}/*/*.npz")))
return sorted(list(self.get_sparse_dir().glob("*/*/*.npz")))

def list_label_files(self, split=None):
# List all label files
if split:
return sorted(list(self.get_label_dir().glob(f"{split}/*.parquet")))
return sorted(list(self.get_label_dir().glob("*/*.parquet")))

def get_model_files(self, window_sizes, aggs, split, shard_num: int):
# Given a shard number, returns the model files
model_files = []
for window_size in window_sizes:
for agg in aggs:
if agg.startswith("static"):
continue
else:
model_files.append(self.get_flat_ts_rep(split, shard_num, window_size, agg))
for agg in aggs:
if agg.startswith("static"):
model_files.append(self.get_flat_static_rep(split, shard_num, agg))
return sorted(model_files)
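
A small, self-contained approximation of what `get_model_files` resolves for one shard (the `ts/` and `static/` directory layout is inferred from the rest of the module, and the window sizes and aggregations are invented):

```python
from pathlib import Path

root, split, shard = Path("/tmp/tabularize"), "train", "0"
window_sizes, aggs = ["30d", "full"], ["code/count", "static/present"]

model_files = []
# Dynamic aggregations: one file per (window_size, agg) pair.
for window_size in window_sizes:
    for agg in aggs:
        if not agg.startswith("static"):
            model_files.append(root / "ts" / split / shard / window_size / f"{agg}.npz")
# Static aggregations: one file per shard, independent of window size.
for agg in aggs:
    if agg.startswith("static"):
        model_files.append(root / "static" / split / shard / f"{agg.split('/')[-1]}.npz")

for fp in sorted(model_files):
    print(fp)
```

The sorted, deterministic ordering matters because the Iterator hstacks these blocks: if shards listed files in different orders, their columns would not align across shards.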
4 changes: 2 additions & 2 deletions src/MEDS_tabular_automl/generate_static_features.py
@@ -19,6 +19,7 @@
STATIC_VALUE_AGGREGATION,
get_events_df,
get_feature_names,
get_unique_time_events_df,
parse_static_feature_column,
)

@@ -53,7 +54,6 @@ def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns)
Returns:
- pd.DataFrame: A merged dataframe containing static and time-series features.
"""
# TODO - Eventually do this duplication at the task specific stage after filtering patients and features
# Make static data sparse and merge it with the time-series data
logger.info("Make static data sparse and merge it with the time-series data")
# Check static_df is sorted and unique
@@ -62,7 +62,7 @@ def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns)
static_df.select(pl.len()).collect().item()
== static_df.select(pl.col("patient_id").n_unique()).collect().item()
)
meds_df = get_events_df(meds_df, feature_columns)
meds_df = get_unique_time_events_df(get_events_df(meds_df, feature_columns))

# load static data as sparse matrix
static_matrix = convert_to_matrix(
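
The switch to `get_unique_time_events_df(get_events_df(...))` here is about row alignment: the static block must end up with one row per unique (patient_id, timestamp) event so it can be stacked against the dynamic blocks. A toy sketch of that requirement (not the repo's `convert_to_matrix`, just invented shapes):

```python
import numpy as np
import scipy.sparse as sp

# One static row per patient (illustrative): 2 patients, 3 static features.
static_matrix = sp.csr_matrix(np.array([[1, 0, 2], [0, 3, 0]]))

# Unique (patient_id, timestamp) events: patient 0 has 2 events, patient 1 has 1.
event_patient_idx = np.array([0, 0, 1])

# Repeat each patient's static row once per event so the static block has the
# same number of rows as the dynamic (window x aggregation) blocks.
static_per_event = static_matrix[event_patient_idx]
print(static_per_event.shape)  # (3, 3)
```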
2 changes: 1 addition & 1 deletion src/MEDS_tabular_automl/utils.py
@@ -386,7 +386,7 @@ def get_events_df(shard_df: pl.DataFrame, feature_columns) -> pl.DataFrame:

def get_unique_time_events_df(events_df: pl.DataFrame):
"""Updates Events DataFrame to have unique timestamps and sorted by patient_id and timestamp."""
assert events_df.select(pl.col("timestamp")).is_nan().any().collect().item() == 0
assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0
# Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline
events_df = (
events_df.drop_nulls("timestamp")
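
The utils.py change swaps `is_nan` for a null count because timestamps are datetimes, not floats: `is_nan` targets floating-point NaNs and cannot detect a missing timestamp. A minimal check of the new behavior (toy frame, assuming polars):

```python
import polars as pl
from datetime import datetime

events_df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 2],
        "timestamp": [datetime(2020, 1, 1), None, datetime(2020, 2, 1)],
    }
)

# Nulls on a Datetime column are counted correctly, so the assert can actually fire.
n_null = events_df.select(pl.col("timestamp").null_count()).collect().item()
assert n_null == 1

# pl.col("timestamp").is_nan() applies to floating-point columns, so the old
# check could not flag the missing timestamp here.
```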