From 95f5694cba920fedf6473d1525936d6c4cc8777a Mon Sep 17 00:00:00 2001 From: Teya Bergamaschi Date: Mon, 3 Jun 2024 19:22:27 +0000 Subject: [PATCH] updates for task cached shard --- scripts/xgboost.py | 38 +++++++++++--------------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/scripts/xgboost.py b/scripts/xgboost.py index abeeede..c5b5f15 100644 --- a/scripts/xgboost.py +++ b/scripts/xgboost.py @@ -200,7 +200,6 @@ def _get_dynamic_shard_by_index(self, idx: int) -> sp.csc_matrix: if not all(file.exists() for file in files): raise ValueError(f"Not all files exist for shard {self._data_shards[idx]}") - shard_name = self._data_shards[idx] dynamic_cscs = [self._load_dynamic_shard_from_file(file, idx) for file in files] fn_name = "_get_dynamic_shard_by_index" @@ -208,14 +207,14 @@ def _get_dynamic_shard_by_index(self, idx: int) -> sp.csc_matrix: self._register_start(key=hstack_key) combined_csc = sp.hstack(dynamic_cscs, format="csc") # TODO: check this - self._register_end(key=hstack_key) - # Filter Rows - valid_indices = self.valid_event_ids[shard_name] - filter_key = f"{fn_name}/filter" - self._register_start(key=filter_key) - out = combined_csc[valid_indices, :] - self._register_end(key=filter_key) - return out + # self._register_end(key=hstack_key) + # # Filter Rows + # valid_indices = self.valid_event_ids[shard_name] + # filter_key = f"{fn_name}/filter" + # self._register_start(key=filter_key) + # out = combined_csc[valid_indices, :] + # self._register_end(key=filter_key) + return combined_csc @TimeableMixin.TimeAs def _get_shard_by_index(self, idx: int) -> tuple[sp.csc_matrix, np.ndarray]: @@ -273,7 +272,7 @@ def next(self, input_data: Callable): # input_data is a function passed in by XGBoost who has the exact same signature of # ``DMatrix`` X, y = self._get_shard_by_index(self._it) # self._data_shards[self._it]) - input_data(data=X, label=y) + input_data(data=sp.csr_matrix(X), label=y) self._it += 1 # Return 1 to let XGBoost know we haven't seen all the files yet. return 1 @@ -295,7 +294,7 @@ def collect_in_memory(self) -> tuple[sp.csc_matrix, np.ndarray]: A tuple where the first element is a sparse matrix containing the feature data, and the second element is a numpy array containing the labels. """ - # TODO: Make this more efficient especially if it is in csc format already + X = [] y = [] for i in range(len(self._data_shards)): @@ -329,22 +328,6 @@ def __init__(self, cfg: DictConfig): self.model = None - # @TimeableMixin.TimeAs - # def _get_callbacks(self): - # """Get the callbacks for training.""" - # callbacks = [] - # if self.cfg.model.early_stopping_rounds is not None: - # es = xgb.callback.EarlyStopping( - # rounds=self.cfg.model.early_stopping_rounds, - # min_delta=1e-3, - # save_best=True, - # maximize=True, - # data_name="tuning", - # metric_name="auc", - # ) - # callbacks.append(es) - # return callbacks - @TimeableMixin.TimeAs def _train(self): """Train the model.""" @@ -353,6 +336,7 @@ def _train(self): self.dtrain, num_boost_round=self.cfg.num_boost_round, early_stopping_rounds=self.cfg.early_stopping_rounds, + # nthreads=self.cfg.nthreads, evals=[(self.dtrain, "train"), (self.dtuning, "tuning")], )