Skip to content

Commit

Permalink
merged xgboost code
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed Jun 1, 2024
2 parents e8d64fd + db18dc5 commit 5b2f7f7
Show file tree
Hide file tree
Showing 4 changed files with 457 additions and 2 deletions.
67 changes: 67 additions & 0 deletions configs/xgboost_sweep.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Raw data
base_dir: /storage/teya/fake
MEDS_cohort_dir: ${base_dir}/MEDS_cohort
tabularized_data_dir: ${base_dir}/flat_reps
model_dir: ${base_dir}/models

# Pre-processing
min_code_inclusion_frequency: 1
window_sizes: [30d]
codes: null
aggs:
- "code/count"
- "value/sum"

# Sharding
n_patients_per_sub_shard: 2

# Misc
do_overwrite: True
do_update: True
seed: 1
tqdm: False

model:
booster: gbtree
device: cpu
nthread: 4
max_depth: 6
eta: 0.3
gamma: 0
subsample: 1
lambda: 1
alpha: 0
tree_method: hist
objective: reg:squaredlogerror

iterator:
keep_data_in_memory: True
keep_static_data_in_memory: True

# Hydra settings for sweep
defaults:
- override hydra/sweeper: optuna
- override hydra/sweeper/sampler: tpe

hydra:
mode: MULTIRUN
verbose: False
sweep:
dir: ${tabularized_data_dir}/.logs/etl/${now:%Y-%m-%d_%H-%M-%S}
run:
dir: ${tabularized_data_dir}/.logs/etl/${now:%Y-%m-%d_%H-%M-%S}

# Optuna Sweeper
sweeper:
sampler:
seed: 1
storage: null
study_name: tabularize_study_${now:%Y-%m-%d_%H-%M-%S}
direction: minimize
n_trials: 10

# Define search space for Optuna
params:
window_sizes: choice([30d, 365d, full], [30d, full], [30d])
# iterator.keep_static_data_in_memory: choice([True], [False])
# iterator.keep_data_in_memory: choice([True], [False])
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy", "pandas", "numba", "tqdm"]
dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy", "pandas", "numba", "tqdm", "xgboost"]

[project.optional-dependencies]
dev = ["pre-commit"]
Expand Down
Loading

0 comments on commit 5b2f7f7

Please sign in to comment.