Merging to form final, release v0.0.1 branch (eventually) #7

Merged: 10 commits, Jun 12, 2024
pyproject.toml (+2, -1)

```diff
@@ -22,7 +22,8 @@
 meds-tab-tabularize-static = "MEDS_tabular_automl.scripts.tabularize_static:main"
 meds-tab-tabularize-time-series = "MEDS_tabular_automl.scripts.tabularize_time_series:main"
 meds-tab-cache-task = "MEDS_tabular_automl.scripts.cache_task:main"
 meds-tab-xgboost = "MEDS_tabular_automl.scripts.launch_xgboost:main"
-meds-tab-xgboost-sweep = "MEDS_tabular_automl.scripts.sweep_xgboost:main"
+generate-permutations = "MEDS_tabular_automl.scripts.generate_permutations:main"
+

 [project.optional-dependencies]
 dev = ["pre-commit"]
```
src/MEDS_tabular_automl/configs/launch_xgboost.yaml (+19, -45)

```diff
@@ -1,6 +1,8 @@
 defaults:
   - default
   - tabularization: default
+  - override hydra/sweeper: optuna
+  - override hydra/sweeper/sampler: tpe
   - _self_
 
 task_name: task
@@ -28,52 +30,24 @@ model_params:
   keep_data_in_memory: True
   binarize_task: True
 
-# Define search space for Optuna
-optuna:
-  study_name: xgboost_sweep_${now:%Y-%m-%d_%H-%M-%S}
-  storage: null
-  load_if_exists: False
-  direction: minimize
-  sampler: null
-  pruner: null
-
-  n_trials: 10
-  n_jobs: 1
-  show_progress_bar: False
-
-  params:
-    suggest_categorical:
-      window_sizes: ${generate_permutations:${tabularization.window_sizes}}
-      aggs: ${generate_permutations:${tabularization.aggs}}
-    suggest_float:
-      eta:
-        low: .001
-        high: 1
-        log: True
-      lambda:
-        low: .001
-        high: 1
-        log: True
-      alpha:
-        low: .001
-        high: 1
-        log: True
-      subsample:
-        low: 0.5
-        high: 1
-      min_child_weight:
-        low: 1e-2
-        high: 100
-    suggest_int:
-      num_boost_round:
-        low: 10
-        high: 1000
-      max_depth:
-        low: 2
-        high: 16
-      min_code_inclusion_frequency:
-        low: 10
-        high: 1_000_000
-        log: True
+hydra:
+  verbose: False
+  sweep:
+    dir: ${output_dir}/.logs/
+  run:
+    dir: ${output_dir}/.logs/
+
+  # Optuna Sweeper
+  sweeper:
+    sampler:
+      seed: 1
+    study_name: null #study_${now:%Y-%m-%d_%H-%M-%S}
+    storage: null
+    direction: minimize
+    n_trials: 10
+
+    # Define search space for Optuna
+    params:
+      tabularization.window_sizes: choice([30d], [30d, 365d], [365d, full])
 
 name: launch_xgboost
```
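The rewritten config delegates hyperparameter search to Hydra's Optuna sweeper plugin and enumerates the window-size combinations inline via `choice(...)`. The old config achieved the same expansion through a `${generate_permutations:...}` resolver; below is a minimal sketch of how such a resolver could be wired up with OmegaConf (hypothetical wiring for illustration; this PR ships `generate_permutations` as a console script instead):

```python
from itertools import combinations

from omegaconf import OmegaConf


def generate_permutations(options) -> str:
    """Expand a list of options into every non-empty combination,
    rendered as comma-separated bracketed lists."""
    combos = []
    for r in range(1, len(options) + 1):
        combos.extend(combinations(options, r))
    return ",".join(f"[{','.join(c)}]" for c in sorted(combos))


OmegaConf.register_new_resolver("generate_permutations", generate_permutations)

cfg = OmegaConf.create(
    {
        "window_sizes": ["30d", "365d"],
        "sweep_space": "${generate_permutations:${window_sizes}}",
    }
)
print(cfg.sweep_space)  # [30d],[30d,365d],[365d]
```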
src/MEDS_tabular_automl/describe_codes.py (+13, -5)

```diff
@@ -109,13 +109,21 @@ def filter_to_codes(
     if allowed_codes is None:
         allowed_codes = get_feature_columns(code_metadata_fp)
     feature_freqs = get_feature_freqs(code_metadata_fp)
+    allowed_codes_set = set(allowed_codes)
 
-    code_freqs = {
-        code: freq
+    filtered_codes = [
+        code
         for code, freq in feature_freqs.items()
-        if (freq >= min_code_inclusion_frequency and code in set(allowed_codes))
-    }
-    return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency])
+        if freq >= min_code_inclusion_frequency and code in allowed_codes_set
+    ]
+    return sorted(filtered_codes)
+
+    # code_freqs = {
+    #     code: freq
+    #     for code, freq in feature_freqs.items()
+    #     if (freq >= min_code_inclusion_frequency and code in set(allowed_codes))
+    # }
+    # return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency])
 
 
 # OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)
```
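A minimal usage sketch of the refactored filter (the path is hypothetical; the keyword names come from the function body above). Hoisting `set(allowed_codes)` out of the comprehension also removes the old version's quirk of rebuilding the set on every iteration:

```python
from pathlib import Path

from MEDS_tabular_automl.describe_codes import filter_to_codes

# Hypothetical metadata path, for illustration only.
codes = filter_to_codes(
    allowed_codes=None,  # None falls back to every column in the metadata file
    min_code_inclusion_frequency=10,
    code_metadata_fp=Path("tabularized_code_metadata.parquet"),
)
print(codes)  # sorted list of codes meeting the frequency threshold
```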
src/MEDS_tabular_automl/generate_summarized_reps.py (+24, -23)

```diff
@@ -7,7 +7,8 @@
 from scipy.sparse import coo_array, csr_array, sparray
 
 from MEDS_tabular_automl.generate_ts_features import get_feature_names, get_flat_ts_rep
-from MEDS_tabular_automl.utils import CODE_AGGREGATIONS, VALUE_AGGREGATIONS, load_tqdm
+from MEDS_tabular_automl.describe_codes import get_feature_columns
+from MEDS_tabular_automl.utils import CODE_AGGREGATIONS, VALUE_AGGREGATIONS, load_tqdm, get_min_dtype
 
 
 def sparse_aggregate(sparse_matrix, agg):
@@ -45,18 +46,17 @@ def aggregate_matrix(windows, matrix, agg, num_features, use_tqdm=False):
     """Aggregate the matrix based on the windows."""
     tqdm = load_tqdm(use_tqdm)
     agg = agg.split("/")[-1]
-    dtype = np.float32
-    matrix = csr_array(matrix.astype(dtype))
-    if agg.startswith("sum"):
-        out_dtype = np.float32
-    else:
-        out_dtype = np.int32
+    matrix = csr_array(matrix)
+    # if agg.startswith("sum"):
+    #     out_dtype = np.float32
+    # else:
+    #     out_dtype = np.int32
     data, row, col = [], [], []
     for i, window in tqdm(enumerate(windows.iter_rows(named=True)), total=len(windows)):
         min_index = window["min_index"]
         max_index = window["max_index"]
         subset_matrix = matrix[min_index : max_index + 1, :]
-        agg_matrix = sparse_aggregate(subset_matrix, agg).astype(out_dtype)
+        agg_matrix = sparse_aggregate(subset_matrix, agg)
         if isinstance(agg_matrix, np.ndarray):
             nozero_ind = np.nonzero(agg_matrix)[0]
             col.append(nozero_ind)
@@ -69,12 +69,16 @@
     else:
         raise TypeError(f"Invalid matrix type {type(agg_matrix)}")
     row = np.concatenate(row)
-    out_matrix = coo_array(
-        (np.concatenate(data), (row, np.concatenate(col))),
-        dtype=out_dtype,
+    data = np.concatenate(data)
+    col = np.concatenate(col)
+    row = row.astype(get_min_dtype(row), copy=False)
+    col = col.astype(get_min_dtype(col), copy=False)
+    data = data.astype(get_min_dtype(data), copy=False)
+    out_matrix = csr_array(
+        (data, (row, col)),
         shape=(windows.shape[0], num_features),
     )
-    return csr_array(out_matrix)
+    return out_matrix
 
 
 def compute_agg(index_df, matrix: sparray, window_size: str, agg: str, num_features: int, use_tqdm=False):
```
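`get_min_dtype` is new in `utils` and its body is not part of this diff. A plausible sketch of the behavior implied by its call sites (an assumption, not the shipped implementation) is to pick the smallest NumPy dtype that can represent an array's values, shrinking the memory footprint of the sparse triplets:

```python
import numpy as np


def get_min_dtype(array: np.ndarray) -> np.dtype:
    # Assumed behavior, inferred from usage: the smallest dtype able to
    # represent both the minimum and maximum value of an integer array.
    if np.issubdtype(array.dtype, np.integer):
        return np.result_type(
            np.min_scalar_type(int(array.min())),
            np.min_scalar_type(int(array.max())),
        )
    return array.dtype  # leave floats untouched in this sketch


row = np.array([0, 17, 250_000])
print(get_min_dtype(row))  # uint32
```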
```diff
@@ -249,17 +253,14 @@ def generate_summary(
     import json
     from pathlib import Path
 
-    feature_columns = json.load(
-        open(
-            Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize") / "feature_columns.json"
-        )
-    )
-    df = pl.scan_parquet(
-        Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed")
-        / "final_cohort"
-        / "train"
-        / "2.parquet"
-    )
+    # feature_columns_fp = Path("/storage/shared/meds_tabular_ml/mimiciv_dataset/mimiciv_MEDS") / "tabularized_code_metadata.parquet"
+    # shard_fp = Path("/storage/shared/meds_tabular_ml/mimiciv_dataset/mimiciv_MEDS/final_cohort/train/0.parquet")
+
+    feature_columns_fp = Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed") / "tabularized_code_metadata.parquet"
+    shard_fp = Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort/train/0.parquet")
+
+    feature_columns = get_feature_columns(feature_columns_fp)
+    df = pl.scan_parquet(shard_fp)
     agg = "code/count"
     index_df, sparse_matrix = get_flat_ts_rep(agg, feature_columns, df)
     generate_summary(
```

Reviewer comment (coderabbitai) on the `import json` line: remove the now-unused import (Ruff F401, line 253: `json` imported but unused).
src/MEDS_tabular_automl/scripts/cache_task.py (+3, -3)

```diff
@@ -39,10 +39,10 @@
 
 def generate_row_cached_matrix(matrix, label_df):
     """Generates row-cached matrix for a given matrix and label_df."""
-    label_len = label_df.select(pl.len()).collect().item()
-    if not matrix.shape[0] == label_len:
+    label_len = label_df.select(pl.col("event_id").max()).collect().item()
+    if matrix.shape[0] <= label_len:
         raise ValueError(
-            f"Matrix and label_df must have the same number of rows: {matrix.shape[0]} != {label_len}"
+            f"Label_df event_ids must be valid indexes of sparse matrix: {matrix.shape[0]} <= {label_len}"
         )
     csr = sp.csr_array(matrix)
     valid_ids = label_df.select(pl.col("event_id")).collect().to_series().to_numpy()
```
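A small sketch of the invariant the new check enforces (synthetic data): every `event_id` must be a usable row index, i.e. strictly less than `matrix.shape[0]`, so the row cache reduces to gathering label rows in label order:

```python
import numpy as np
import scipy.sparse as sp

# Synthetic stand-ins for the tabularized matrix and the label event_ids.
matrix = sp.csr_array(np.arange(20, dtype=np.float32).reshape(5, 4))
valid_ids = np.array([0, 3, 1])

assert valid_ids.max() < matrix.shape[0]  # what generate_row_cached_matrix checks
cached = matrix[valid_ids, :]  # rows reordered to match the labels
print(cached.shape)  # (3, 4)
```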
src/MEDS_tabular_automl/scripts/generate_permutations.py (new file, +47)

```python
import sys
from itertools import combinations


def format_print(permutations):
    """Print combinations as comma-separated bracketed lists.

    Args:
        permutations: List of tuples of options (all non-empty combinations).

    Example:
        >>> format_print([('2',), ('2', '3'), ('2', '3', '4'), ('2', '4'), ('3',), ('3', '4'), ('4',)])
        [2],[2,3],[2,3,4],[2,4],[3],[3,4],[4]
    """
    out_str = ""
    for item in permutations:
        out_str += f"[{','.join(item)}],"
    out_str = out_str[:-1]
    print(out_str)


def get_permutations(list_of_options):
    """Generate all non-empty combinations of a list of options and print them.

    Args:
        list_of_options (list): List of options.

    Returns:
        None: the sorted combinations are printed rather than returned.

    Example:
        >>> get_permutations(['2', '3', '4'])
        [2],[2,3],[2,3,4],[2,4],[3],[3,4],[4]
    """
    permutations = []
    for i in range(1, len(list_of_options) + 1):
        permutations.extend(list(combinations(list_of_options, r=i)))
    format_print(sorted(permutations))
```
Reviewer comment (coderabbitai) on lines 21-37: `get_permutations` generates the combinations efficiently; consider handling the edge case where the input list is empty (e.g. `if not list_of_options: return []`).

```python
def main():
    """Generate all possible combinations of a list of options."""
    list_of_options = list(sys.argv[1].strip("[]").split(","))
    get_permutations(list_of_options)


if __name__ == "__main__":
    main()
```
Reviewer comment (coderabbitai) on lines 40-47: `main` processes `sys.argv[1]` directly without validating the input format; repository searches confirmed there is no try/except or input validation in this script. Consider adding error handling for malformed arguments.
src/MEDS_tabular_automl/scripts/launch_xgboost.py (+14, -22)

```diff
@@ -1,4 +1,5 @@
 from collections.abc import Callable, Mapping
+from datetime import datetime
 from importlib.resources import files
 from pathlib import Path
@@ -12,7 +13,7 @@
 from omegaconf import DictConfig, OmegaConf
 from sklearn.metrics import roc_auc_score
 
-from MEDS_tabular_automl.describe_codes import get_feature_columns, get_feature_freqs
+from MEDS_tabular_automl.describe_codes import get_feature_columns
 from MEDS_tabular_automl.file_name import get_model_files, list_subdir_files
 from MEDS_tabular_automl.utils import get_feature_indices, hydra_loguru_init
@@ -188,18 +189,8 @@ def _get_dynamic_shard_by_index(self, idx: int) -> sp.csc_matrix:
 
         dynamic_cscs = [self._load_dynamic_shard_from_file(file, idx) for file in files]
 
-        fn_name = "_get_dynamic_shard_by_index"
-        hstack_key = f"{fn_name}/hstack"
-        self._register_start(key=hstack_key)
-
-        combined_csc = sp.hstack(dynamic_cscs, format="csc")  # TODO: check this
-        # self._register_end(key=hstack_key)
-        # # Filter Rows
-        # valid_indices = self.valid_event_ids[shard_name]
-        # filter_key = f"{fn_name}/filter"
-        # self._register_start(key=filter_key)
-        # out = combined_csc[valid_indices, :]
-        # self._register_end(key=filter_key)
+        combined_csc = sp.hstack(dynamic_cscs, format="csc")
 
         return combined_csc
 
     @TimeableMixin.TimeAs
@@ -388,30 +379,31 @@ def main(cfg: DictConfig) -> float:
     Returns:
     - float: Evaluation result.
     """
-
+    print(OmegaConf.to_yaml(cfg))
     if not cfg.loguru_init:
         hydra_loguru_init()
 
     model = XGBoostModel(cfg)
     model.train()
-    auc = model.evaluate()
-    logger.info(f"AUC: {auc}")
 
     print(
         "Time Profiling for window sizes ",
         f"{cfg.tabularization.window_sizes} and min ",
-        "code frequency of {cfg.tabularization.min_code_inclusion_frequency}:",
+        f"code frequency of {cfg.tabularization.min_code_inclusion_frequency}:",
     )
     print("Train Time: \n", model._profile_durations())
-    print("Train Iterator Time: \n", model.itrain._profile_durations())
-    print("Tuning Iterator Time: \n", model.ituning._profile_durations())
-    print("Held Out Iterator Time: \n", model.iheld_out._profile_durations())
+    # print("Train Iterator Time: \n", model.itrain._profile_durations())
+    # print("Tuning Iterator Time: \n", model.ituning._profile_durations())
+    # print("Held Out Iterator Time: \n", model.iheld_out._profile_durations())
 
     # save model
     save_dir = Path(cfg.output_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
 
     logger.info(f"Saving the model to directory: {save_dir}")
-    model.model.save_model(save_dir / "model.json")
+    auc = model.evaluate()
+    logger.info(f"AUC: {auc}")
+    model_time = datetime.now().strftime("%H%M%S%f")
+    model.model.save_model(save_dir / f"{auc:.4f}_model_{model_time}.json")
     return auc
```
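With this change each run writes a uniquely named, AUC-stamped artifact instead of overwriting `model.json`. A quick illustration of the resulting filename (values illustrative):

```python
from datetime import datetime

auc = 0.8423  # illustrative AUC
model_time = datetime.now().strftime("%H%M%S%f")
print(f"{auc:.4f}_model_{model_time}.json")  # e.g. 0.8423_model_142530123456.json
```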
src/MEDS_tabular_automl/scripts/tabularize_static.py (+5, -4)

```diff
@@ -97,10 +97,11 @@ def compute_fn(_):
             cfg.input_code_metadata_fp,
         )
         feature_freqs = get_feature_freqs(cfg.input_code_metadata_fp)
-        filtered_feeature_freqs = {
-            code: count for code, count in feature_freqs.items() if code in filtered_feature_columns
+        filtered_feature_columns_set = set(filtered_feature_columns)
+        filtered_feature_freqs = {
+            code: count for code, count in feature_freqs.items() if code in filtered_feature_columns_set
         }
-        return convert_to_df(filtered_feeature_freqs)
+        return convert_to_df(filtered_feature_freqs)
 
     def write_fn(data, out_fp):
         data.write_parquet(out_fp)
@@ -116,6 +117,7 @@ def write_fn(data, out_fp):
         do_overwrite=cfg.do_overwrite,
         do_return=False,
     )
+
    # Step 2: Produce static data representation
    meds_shard_fps = list_subdir_files(cfg.input_dir, "parquet")
    feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
@@ -125,7 +127,6 @@ def write_fn(data, out_fp):
    static_aggs = [agg for agg in aggs if agg in [STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION]]
    tabularization_tasks = list(product(meds_shard_fps, static_aggs))
    np.random.shuffle(tabularization_tasks)
-
    for shard_fp, agg in iter_wrapper(tabularization_tasks):
        out_fp = (
            Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, shard_fp) / "none" / agg
```
src/MEDS_tabular_automl/scripts/tabularize_time_series.py (+7, -0)

```diff
@@ -27,6 +27,7 @@
     load_tqdm,
     write_df,
 )
+import gc
 
 config_yaml = files("MEDS_tabular_automl").joinpath("configs/tabularization.yaml")
 if not config_yaml.is_file():
@@ -101,13 +102,19 @@ def compute_fn(shard_df):
             agg,
         )
         assert summary_df.shape[1] > 0, "No data found in the summarized dataframe"
+        del index_df
+        del sparse_matrix
+        gc.collect()
 
         logger.info("Writing pivot file")
         return summary_df
 
     def write_fn(out_matrix, out_fp):
         coo_matrix = out_matrix.tocoo()
         write_df(coo_matrix, out_fp, do_overwrite=cfg.do_overwrite)
+        del coo_matrix
+        del out_matrix
+        gc.collect()
 
     rwlock_wrap(
         shard_fp,
```