Skip to content

Commit

Permalink
updates for task cached shard
Browse files Browse the repository at this point in the history
  • Loading branch information
teyaberg committed Jun 3, 2024
1 parent 425b79d commit 95f5694
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions scripts/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,21 @@ def _get_dynamic_shard_by_index(self, idx: int) -> sp.csc_matrix:
if not all(file.exists() for file in files):
raise ValueError(f"Not all files exist for shard {self._data_shards[idx]}")

shard_name = self._data_shards[idx]
dynamic_cscs = [self._load_dynamic_shard_from_file(file, idx) for file in files]

fn_name = "_get_dynamic_shard_by_index"
hstack_key = f"{fn_name}/hstack"
self._register_start(key=hstack_key)

combined_csc = sp.hstack(dynamic_cscs, format="csc") # TODO: check this
self._register_end(key=hstack_key)
# Filter Rows
valid_indices = self.valid_event_ids[shard_name]
filter_key = f"{fn_name}/filter"
self._register_start(key=filter_key)
out = combined_csc[valid_indices, :]
self._register_end(key=filter_key)
return out
# self._register_end(key=hstack_key)
# # Filter Rows
# valid_indices = self.valid_event_ids[shard_name]
# filter_key = f"{fn_name}/filter"
# self._register_start(key=filter_key)
# out = combined_csc[valid_indices, :]
# self._register_end(key=filter_key)
return combined_csc

@TimeableMixin.TimeAs
def _get_shard_by_index(self, idx: int) -> tuple[sp.csc_matrix, np.ndarray]:
Expand Down Expand Up @@ -273,7 +272,7 @@ def next(self, input_data: Callable):
# input_data is a function passed in by XGBoost which has the exact same signature as
# ``DMatrix``
X, y = self._get_shard_by_index(self._it) # self._data_shards[self._it])
input_data(data=X, label=y)
input_data(data=sp.csr_matrix(X), label=y)
self._it += 1
# Return 1 to let XGBoost know we haven't seen all the files yet.
return 1
Expand All @@ -295,7 +294,7 @@ def collect_in_memory(self) -> tuple[sp.csc_matrix, np.ndarray]:
A tuple where the first element is a sparse matrix containing the
feature data, and the second element is a numpy array containing the labels.
"""
# TODO: Make this more efficient especially if it is in csc format already

X = []
y = []
for i in range(len(self._data_shards)):
Expand Down Expand Up @@ -329,22 +328,6 @@ def __init__(self, cfg: DictConfig):

self.model = None

# @TimeableMixin.TimeAs
# def _get_callbacks(self):
# """Get the callbacks for training."""
# callbacks = []
# if self.cfg.model.early_stopping_rounds is not None:
# es = xgb.callback.EarlyStopping(
# rounds=self.cfg.model.early_stopping_rounds,
# min_delta=1e-3,
# save_best=True,
# maximize=True,
# data_name="tuning",
# metric_name="auc",
# )
# callbacks.append(es)
# return callbacks

@TimeableMixin.TimeAs
def _train(self):
"""Train the model."""
Expand All @@ -353,6 +336,7 @@ def _train(self):
self.dtrain,
num_boost_round=self.cfg.num_boost_round,
early_stopping_rounds=self.cfg.early_stopping_rounds,
# nthreads=self.cfg.nthreads,
evals=[(self.dtrain, "train"), (self.dtuning, "tuning")],
)

Expand Down

0 comments on commit 95f5694

Please sign in to comment.