Commit 85f38b5
Merge branch 'xgboost' into esgpt_caching
Oufattole committed Jun 1, 2024
2 parents b9d057b + 795b532 commit 85f38b5
Showing 2 changed files with 59 additions and 31 deletions.
3 changes: 1 addition & 2 deletions configs/xgboost_sweep.yaml
@@ -26,9 +26,8 @@ tqdm: True
 model:
   booster: gbtree
   device: cpu
-  epochs: 1
   tree_method: hist
-  objective: binary:logistic
+  objective: reg:squarederror

 iterator:
   keep_data_in_memory: False
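Note on the objective change: swapping binary:logistic for reg:squarederror turns the sweep from binary classification into squared-error regression, and the epochs key is dropped (XGBoost's learning API takes num_boost_round rather than an epochs parameter). As a hedged sketch, not the repo's actual code, the model block above would map onto xgb.train roughly like this:

    import xgboost as xgb

    # Assumed translation of the YAML "model" block into training parameters.
    params = {
        "booster": "gbtree",
        "device": "cpu",
        "tree_method": "hist",
        "objective": "reg:squarederror",  # was binary:logistic before this commit
    }
    # dtrain would be the external-memory DMatrix built in scripts/xgboost_sweep.py:
    # booster = xgb.train(params, dtrain, num_boost_round=100)  # round count illustrative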
87 changes: 58 additions & 29 deletions scripts/xgboost_sweep.py
@@ -3,6 +3,7 @@
 from collections.abc import Callable, Mapping
 from datetime import datetime
 from pathlib import Path
+from timeit import timeit

 import hydra
 import numpy as np
@@ -25,11 +26,12 @@ def __init__(self, cfg: DictConfig, split: str = "train"):
         self.cfg = cfg
         self.data_path = Path(cfg.tabularized_data_dir)
         self.dynamic_data_path = self.data_path / "sparse" / split
-        self.label_data_path = self.data_path / "task" / split
-        self._data_shards = [4]  # sort([shard.stem for shard in list(self.static_data_path.glob("*."))])
+        self.task_data_path = self.data_path / "task" / split
+        self._data_shards = sorted(
+            [shard.stem for shard in list(self.task_data_path.glob("*.parquet"))]
+        )  # [2, 4, 5] #
         self.valid_event_ids, self.labels = self.load_labels()
-        # TODO: need to fix this path/logic
-        self.window_set, self.aggs_set, self.codes_set = self._get_inclusion_sets()
+        self.window_set, self.aggs_set, self.codes_set, self.num_features = self._get_inclusion_sets()

         self._it = 0

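The shard list is no longer hard-coded to [4]; it is now discovered from the task label files on disk. A minimal sketch of that discovery, assuming a layout like <tabularized_data_dir>/task/train/0.parquet (paths illustrative):

    from pathlib import Path

    task_data_path = Path("data/task/train")  # hypothetical directory
    data_shards = sorted(p.stem for p in task_data_path.glob("*.parquet"))
    # e.g. ["0", "1", "2"] -- one entry per label shard on disk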
@@ -46,7 +48,7 @@ def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]:
                 in the sparse matrix
             dictionary from shard number to list of labels for these valid event ids
         """
-        label_fps = {shard: self.label_data_path / f"{shard}.parquet" for shard in self._data_shards}
+        label_fps = {shard: self.task_data_path / f"{shard}.parquet" for shard in self._data_shards}
         cached_labels, cached_event_ids = dict(), dict()
         for shard, label_fp in label_fps.items():
             label_df = pl.scan_parquet(label_fp)
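load_labels opens each shard's labels lazily with pl.scan_parquet, so only the selected columns are materialized. A hedged sketch of the per-shard loop, where the column names "event_id" and "label" are assumptions not confirmed by this diff:

    from pathlib import Path
    import polars as pl

    label_fps = {"0": Path("data/task/train/0.parquet")}  # hypothetical paths
    cached_event_ids, cached_labels = {}, {}
    for shard, label_fp in label_fps.items():
        collected = pl.scan_parquet(label_fp).select(["event_id", "label"]).collect()
        cached_event_ids[shard] = collected["event_id"].to_list()
        cached_labels[shard] = collected["label"].to_list()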
@@ -78,7 +80,8 @@ def _get_code_set(self) -> set:
             codes_set = frequency_set
         else:
             codes_set = None  # set(feature_columns)
-        return codes_set
+        # TODO: make sure we aren't filtering out static columns!!!
+        return list(codes_set), len(feature_columns)

     def _get_inclusion_sets(self) -> tuple[set, set, np.array]:
         """Get the inclusion sets for aggregations, window sizes, and a mask for minimum code frequency.
@@ -123,9 +126,11 @@ def _get_inclusion_sets(self) -> tuple[set, set, np.array]:
         if self.cfg.window_sizes is not None:
             window_set = set(self.cfg.window_sizes)

-        return aggs_set, window_set, self._get_code_set()
+        codes_set, num_features = self._get_code_set()

-    def _load_dynamic_shard_from_file(self, path: Path) -> sp.csc_matrix:
+        return sorted(window_set), sorted(aggs_set), sorted(codes_set), num_features
+
+    def _load_dynamic_shard_from_file(self, path: Path, idx: int) -> sp.csc_matrix:
         """Load a sparse shard into memory.
         Args:
@@ -157,32 +162,43 @@ def _load_dynamic_shard_from_file(self, path: Path) -> sp.csc_matrix:
         ... iterator_instance = Iterator(cfg)
         ... iterator_instance.codes_mask = np.array([False, True, True])
         ... loaded_shard = iterator_instance._load_dynamic_shard_from_file(sample_shard_path)
-        ... assert isinstance(loaded_shard, sp.csc_matrix)
-        ... expected_csc = sp.csc_matrix(sample_filtered_data)
+        ... assert isinstance(loaded_shard, sp.csr_matrix)
+        ... expected_csr = sp.csr_matrix(sample_filtered_data)
         ... assert sp.issparse(loaded_shard)
-        ... assert np.array_equal(loaded_shard.data, expected_csc.data)
-        ... assert np.array_equal(loaded_shard.indices, expected_csc.indices)
-        ... assert np.array_equal(loaded_shard.indptr, expected_csc.indptr)
+        ... assert np.array_equal(loaded_shard.data, expected_csr.data)
+        ... assert np.array_equal(loaded_shard.indices, expected_csr.indices)
+        ... assert np.array_equal(loaded_shard.indptr, expected_csr.indptr)
         """
         # column_shard is of form event_idx, feature_idx, value
         column_shard = np.load(path).T  # TODO: Fix this!!!

         shard = sp.csc_matrix(
             (column_shard[:, 0], (column_shard[:, 1], column_shard[:, 2])),
             shape=(
-                max(column_shard[:, 1].astype(np.int32) + 1),
-                max(column_shard[:, 2].astype(np.int32) + 1),
+                max(self.valid_event_ids[self._data_shards[idx]], column_shard[:, 1]) + 1,
+                self.num_features,
             ),
         )
         return self._filter_shard_on_codes_and_freqs(shard)
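Each .npy shard stores (event_idx, feature_idx, value) triples, and the key change above is that the matrix shape is now pinned to a shared feature count (self.num_features) rather than inferred per shard, so every shard's columns line up before concatenation. A self-contained toy sketch of that construction, following the column layout the code comment describes (the surrounding TODO suggests the real column order is still being fixed):

    import numpy as np
    import scipy.sparse as sp

    triples = np.array([[0.0, 2.0, 1.5], [1.0, 0.0, 3.0], [3.0, 2.0, 2.0]])  # toy data
    rows = triples[:, 0].astype(np.int32)   # event indices
    cols = triples[:, 1].astype(np.int32)   # feature indices
    vals = triples[:, 2]                    # values
    num_events, num_features = 4, 5         # fixed globally, like self.num_features
    shard = sp.csc_matrix((vals, (rows, cols)), shape=(num_events, num_features))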

     def _get_dynamic_shard_by_index(self, idx: int) -> sp.csr_matrix:
         """Load a specific shard of dynamic data from disk and return it as a sparse matrix after filtering
-        column inclusion."""
-        files = list(self.dynamic_data_path.glob(f"*/*/*/{self._data_shards[idx]}.npy"))
-        files = sorted([file for file in files if self._filter_shard_files_on_window_and_aggs(file)])
-        dynamic_cscs = [self._load_dynamic_shard_from_file(file) for file in files]
-        return sp.hstack(dynamic_cscs).tocsr()[self._valid_event_ids[idx], :]
+        column inclusion.
+
+        Args:
+        - idx (int): Index of the shard to load.
+
+        Returns:
+        - sp.csr_matrix: Filtered sparse matrix.
+        """
+        shard_name = self._data_shards[idx]
+        shard_pattern = f"*/*/*/{shard_name}.npy"
+        files = self.dynamic_data_path.glob(shard_pattern)
+        valid_files = sorted(file for file in files if self._filter_shard_files_on_window_and_aggs(file))
+        dynamic_csrs = [self._load_dynamic_shard_from_file(file, idx) for file in valid_files]
+        combined_csr = sp.hstack(dynamic_csrs, format="csr")
+        valid_indices = self.valid_event_ids[shard_name]
+        return combined_csr[valid_indices, :]
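The per-(window, aggregation) blocks are concatenated column-wise and then restricted to the labeled event rows. A toy sketch of that hstack-then-slice step (sizes and indices are illustrative):

    import numpy as np
    import scipy.sparse as sp

    block_a = sp.random(6, 3, density=0.5, format="csr", random_state=0)
    block_b = sp.random(6, 2, density=0.5, format="csr", random_state=1)
    combined_csr = sp.hstack([block_a, block_b], format="csr")  # shape (6, 5)
    valid_indices = np.array([0, 2, 5])        # hypothetical labeled event rows
    filtered = combined_csr[valid_indices, :]  # shape (3, 5)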

     def _get_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
         """Load a specific shard of data from disk and concatenate with static data.
@@ -198,22 +214,25 @@ def _get_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
         dynamic_df = self._get_dynamic_shard_by_index(idx)
         logger.debug(f"Dynamic data loading took {datetime.now() - time}")
         time = datetime.now()
-        label_df = self._get_label_by_index(idx)
+        label_df = self.labels[self._data_shards[idx]]
         logger.debug(f"Task data loading took {datetime.now() - time}")

-        return dynamic_df, label_df["label"].values
+        return dynamic_df, label_df

     def _filter_shard_files_on_window_and_aggs(self, file: Path) -> bool:
         parts = file.relative_to(self.dynamic_data_path).parts
-        if not parts:
+        if len(parts) < 2:
             return False

         windows_part = parts[0]
         aggs_part = "/".join(parts[1:-1])

-        return (self.window_set is None or windows_part in self.window_set) and (
-            self.aggs_set is None or aggs_part in self.aggs_set
-        )
+        if self.window_set is not None and windows_part not in self.window_set:
+            return False

+        if self.aggs_set is not None and aggs_part not in self.aggs_set:
+            return False
+
+        return True
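For orientation, shard files are globbed as */*/*/{shard}.npy under dynamic_data_path, so parts[0] names the window size and the joined middle parts name the aggregation. A walk-through with an illustrative path (the directory names are assumptions):

    from pathlib import Path

    rel = Path("30d/code/count/4.npy")  # hypothetical path relative to dynamic_data_path
    parts = rel.parts                   # ("30d", "code", "count", "4.npy")
    windows_part = parts[0]             # "30d"
    aggs_part = "/".join(parts[1:-1])   # "code/count"
    window_set = ["30d", "365d"]        # stand-ins for the sorted inclusion sets
    aggs_set = ["code/count", "value/sum"]
    keep = windows_part in window_set and aggs_part in aggs_set  # True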

     def _filter_shard_on_codes_and_freqs(self, df: sp.csc_matrix) -> sp.csc_matrix:
         """Filter the dynamic data frame based on the inclusion sets. Given the codes_mask, filter the data
@@ -227,7 +246,7 @@ def _filter_shard_on_codes_and_freqs(self, df: sp.csc_matrix) -> sp.csc_matrix:
"""
if self.codes_set is None:
return df
return df[:, list({index for index in self.codes_set if index < df.shape[1]})]
return df[:, self.codes_set] # [:, list({index for index in self.codes_set if index < df.shape[1]})]

     def next(self, input_data: Callable):
         """Advance the iterator by 1 step and pass the data to XGBoost. This function is called by XGBoost
@@ -358,8 +377,12 @@ def xgboost(cfg: DictConfig) -> float:
     Returns:
     - float: Evaluation result.
     """
+    logger.debug("Initializing XGBoost model")
     model = XGBoostModel(cfg)
+    logger.debug("Training XGBoost model")
+    time = datetime.now()
     model.train()
+    logger.debug(f"Training took {datetime.now() - time}")
     # save model
     save_dir = (
         Path(cfg.model_dir)
@@ -369,8 +392,14 @@
     save_dir.mkdir(parents=True, exist_ok=True)

     model.model.save_model(save_dir / f"{np.random.randint(100000, 999999)}_model.json")
+
     return model.evaluate()
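The booster is serialized to JSON under a random six-digit name. A hedged sketch of loading such a file back with XGBoost's standard Booster API (the path is illustrative):

    import xgboost as xgb

    booster = xgb.Booster()
    booster.load_model("model_dir/123456_model.json")  # hypothetical saved file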


 if __name__ == "__main__":
-    xgboost()
+    # start_time = datetime.now()
+    # xgboost()
+    # logger.debug(f"Total time: {datetime.now() - start_time}")
+    num = 10
+    time = timeit(xgboost, number=num) / num
+    logger.debug(f"Training time averaged over {num} runs: {time}")
