diff --git a/scripts/xgboost_sweep.py b/scripts/xgboost_sweep.py
index 9172f3a..f8f74b4 100644
--- a/scripts/xgboost_sweep.py
+++ b/scripts/xgboost_sweep.py
@@ -27,7 +27,7 @@ def __init__(self, cfg: DictConfig, split: str = "train"):
         self.static_data_path = self.data_path / "static" / split
         self._data_shards = [shard.stem for shard in list(self.static_data_path.glob("*.parquet"))]
         if cfg.iterator.keep_static_data_in_memory:
-            self._static_shards = self._get_static_shards()
+            self._static_shards = self._collect_static_shards_in_memory()

         self.codes_set, self.aggs_set, self.min_frequency_set, self.window_set = self._get_inclusion_sets()

@@ -46,32 +46,26 @@ def _get_inclusion_sets(self) -> tuple[set, set, set]:
         codes_set = None
         aggs_set = None
         min_frequency_set = None
+        window_set = None
+
         if self.cfg.codes is not None:
             codes_set = set(self.cfg.codes)
+
         if self.cfg.aggs is not None:
             aggs_set = set(self.cfg.aggs)
+
         if self.cfg.min_code_inclusion_frequency is not None:
             with open(self.data_path / "feature_freqs.json") as f:
                 feature_freqs = json.load(f)
             min_frequency_set = {
                 key for key, value in feature_freqs.items() if value >= self.cfg.min_code_inclusion_frequency
             }
-        window_set = set(self.cfg.window_sizes)
+        if self.cfg.window_sizes is not None:
+            window_set = set(self.cfg.window_sizes)

         return codes_set, aggs_set, min_frequency_set, window_set

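The refactored `_get_inclusion_sets` treats every filter as optional: a `None` set downstream means "don't filter", which is why the previously unconditional `window_set = set(self.cfg.window_sizes)` (which would fail when `window_sizes` was unset) is now guarded like the other three. A minimal standalone sketch of the same pattern — the config field names and `feature_freqs.json` come from the diff, everything else is illustrative:

```python
# Sketch only: builds the four optional inclusion sets from a config object.
import json
from pathlib import Path


def build_inclusion_sets(cfg, data_path: Path):
    """Return (codes, aggs, min_frequency, windows); None means "don't filter"."""
    codes = set(cfg.codes) if cfg.codes is not None else None
    aggs = set(cfg.aggs) if cfg.aggs is not None else None
    windows = set(cfg.window_sizes) if cfg.window_sizes is not None else None

    min_frequency = None
    if cfg.min_code_inclusion_frequency is not None:
        # Keep only feature keys that clear the configured frequency floor.
        freqs = json.loads((data_path / "feature_freqs.json").read_text())
        min_frequency = {k for k, v in freqs.items() if v >= cfg.min_code_inclusion_frequency}

    return codes, aggs, min_frequency, windows
```

Keeping `None` rather than an empty set is what lets the filter methods later in the diff interpret "unset" as "include everything".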
-    def _load_static_shard_by_index(self, idx: int) -> sp.csc_matrix:
-        """Load a static shard into memory.
-
-        Args:
-        - idx (int): Index of the shard to load.
-
-        Returns:
-        - sp.csc_matrix: Sparse matrix with the static shard.
-        """
-        return pd.read_parquet(self.static_data_path / f"{self._data_shards[int(idx)]}.parquet")
-
-    def _get_static_shards(self) -> dict:
+    def _collect_static_shards_in_memory(self) -> dict:
         """Load static shards into memory.

         Returns:
@@ -82,59 +76,16 @@ def _get_static_shards(self) -> dict:
             static_shards[iter] = self._load_static_shard_by_index(iter)
         return static_shards

-    def _sparsify_shard(self, df: pd.DataFrame) -> tuple[sp.csc_matrix, np.ndarray]:
-        """Make X and y as scipy sparse arrays for XGBoost.
-
-        Args:
-        - df (pandas.DataFrame): Data frame to sparsify.
-
-        Returns:
-        - tuple[scipy.sparse.csr_matrix, numpy.ndarray]: Tuple of feature data and labels.
-        """
-        labels = df.loc[:, [col for col in df.columns if col.endswith("/task")]]
-        data = df.drop(columns=labels.columns)
-        for col in data.columns:
-            if not isinstance(data[col].dtype, pd.SparseDtype):
-                data[col] = pd.arrays.SparseArray(data[col])
-        sparse_matrix = data.sparse.to_coo()
-        return csr_matrix(sparse_matrix), labels.values
-
-    def _validate_shard_file_inclusion(self, file: Path) -> bool:
-        parts = file.relative_to(self.dynamic_data_path).parts
-        if not parts:
-            return False
-
-        windows_part = parts[0]
-        aggs_part = "/".join(parts[1:-1])
-
-        return (self.window_set is None or windows_part in self.window_set) and (
-            self.aggs_set is None or aggs_part in self.aggs_set
-        )
-
-    def _assert_correct_sorting(self, shard: pd.DataFrame):
-        """Assert that the shard is sorted correctly."""
-        if "timestamp" in shard.columns:
-            sort_columns = ["patient_id", "timestamp"]
-        else:
-            sort_columns = ["patient_id"]
-        assert shard[sort_columns].equals(shard[sort_columns].sort_values(by=sort_columns)), (
-            "Shard is not sorted on correctly. "
-            "Please ensure that the data is sorted on patient_id and timestamp, if applicable."
-        )
-
-    def _get_sparse_dynamic_shard_from_file(self, path: Path) -> pd.DataFrame:
-        """Load a sparse shard into memory.
-
-        This returns a shard as a pandas dataframe, asserted that it is
-        sorted on patient id and timestamp, if included.
+    def _load_static_shard_by_index(self, idx: int) -> pd.DataFrame:
+        """Load a static shard into memory.

         Args:
-        - path (Path): Path to the sparse shard.
+        - idx (int): Index of the shard to load.

         Returns:
-        - pd.DataFrame: Data frame with the sparse shard.
+        - pd.DataFrame: Data frame with the static shard.
         """
-        shard = pd.read_pickle(path)
-        self._assert_correct_sorting(shard)
-        return shard.drop(columns=["patient_id", "timestamp"])
+        return pd.read_parquet(self.static_data_path / f"{self._data_shards[int(idx)]}.parquet")

     def _get_static_shard_by_index(self, idx: int) -> pd.DataFrame:
         """Get the static shard from memory or disk.
@@ -150,7 +101,7 @@ def _get_static_shard_by_index(self, idx: int) -> pd.DataFrame:
         else:
             return self._load_static_shard_by_index(self._data_shards[idx])

-    def _get_task(self, idx: int) -> pd.DataFrame:
+    def _get_task_by_index(self, idx: int) -> pd.DataFrame:
         """Get the task data for a specific shard.

         Args:
@@ -165,29 +116,32 @@ def _get_task(self, idx: int) -> pd.DataFrame:
         shard["label"] = np.random.randint(0, 2, shard.shape[0])
         return shard[["patient_id", "timestamp", "label"]]

-    def _filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
-        """Filter the dynamic data frame based on the inclusion sets.
+    def _load_dynamic_shard_from_file(self, path: Path) -> pd.DataFrame:
+        """Load a sparse shard into memory. This returns the shard as a pandas dataframe, asserting that
+        it is sorted on patient_id and, if present, timestamp.

         Args:
-        - df (pd.DataFrame): Data frame to filter.
+        - path (Path): Path to the sparse shard.

         Returns:
-        - pd.DataFrame: Filtered data frame.
+        - pd.DataFrame: Data frame with the sparse shard.
         """
-        code_parts = ["/".join(col.split("/")[1:-2]) for col in df.columns]
-        frequency_parts = ["/".join(col.split("/")[1:-1]) for col in df.columns]
+        shard = pd.read_pickle(path)
+        self._assert_correct_sorting(shard)
+        return shard.drop(columns=["patient_id", "timestamp"])

-        filtered_columns = [
-            col
-            for col, code_part, freq_part in zip(df.columns, code_parts, frequency_parts)
-            if (self.codes_set is None or code_part in self.codes_set)
-            and (self.min_frequency_set is None or freq_part in self.min_frequency_set)
-        ]
-        filtered_columns.extend([col for col in df.columns if col.endswith("/task")])
+    def _get_dynamic_shard_by_index(self, idx: int) -> pd.DataFrame:
+        """Load a specific shard of dynamic data from disk and return it as a pandas dataframe after
+        filtering columns by the inclusion sets."""
+        files = list(self.dynamic_data_path.glob(f"*/*/*/{self._data_shards[idx]}.pkl"))

-        return df[filtered_columns]
+        files = [file for file in files if self._filter_shard_files_on_window_and_aggs(file)]

-    def _load_dynamic_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
+        dynamic_dfs = [self._load_dynamic_shard_from_file(file) for file in files]
+        dynamic_df = pd.concat(dynamic_dfs, axis=1)
+        return self._filter_shard_on_codes_and_freqs(dynamic_df)
+
+    def _get_shard_by_index(self, idx: int) -> tuple[sp.csr_matrix, np.ndarray]:
         """Load a specific shard of data from disk and concatenate with static data.

         Args:
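`_get_dynamic_shard_by_index` globs `<window>/<agg parts...>/<shard>.pkl` under `dynamic_data_path` and keeps only files whose window and aggregation path components pass the inclusion sets, so `parts[0]` is the window size and `parts[1:-1]` the aggregation name. A self-contained sketch of that path-based filter — the root directory and the example window/agg values (`"30d"`, `"code/count"`) are hypothetical:

```python
# Sketch of the window/agg path filter used when selecting dynamic shard files.
from pathlib import Path


def keep_file(file: Path, root: Path, window_set: set | None, aggs_set: set | None) -> bool:
    parts = file.relative_to(root).parts
    if not parts:
        return False
    # Layout assumed here: <window_size>/<agg parts...>/<shard>.pkl
    window, agg = parts[0], "/".join(parts[1:-1])
    return (window_set is None or window in window_set) and (aggs_set is None or agg in aggs_set)


root = Path("data/dynamic/train")  # hypothetical root
files = [f for f in root.glob("*/*/*/shard_0.pkl") if keep_file(f, root, {"30d"}, {"code/count"})]
```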
""" - files = list(self.dynamic_data_path.glob(f"*/*/*/{self._data_shards[idx]}.pkl")) - - files = [file for file in files if self._validate_shard_file_inclusion(file)] - - dynamic_dfs = [self._get_sparse_dynamic_shard_from_file(file) for file in files] - dynamic_df = pd.concat(dynamic_dfs, axis=1) - dynamic_df = self._filter_df(dynamic_df) + dynamic_df = self._get_dynamic_shard_by_index(idx) # TODO: add in some type checking etc for safety static_df = self._get_static_shard_by_index(idx) - task_df = self._get_task(idx) + task_df = self._get_task_by_index(idx) task_df = task_df.rename( columns={col: f"{col}/task" for col in task_df.columns if col not in ["patient_id", "timestamp"]} ) df = pd.merge(task_df, static_df, on=["patient_id"], how="left") self._assert_correct_sorting(df) - df = self._filter_df(df) + df = self._filter_shard_on_codes_and_freqs(df) df = pd.concat([df, dynamic_df], axis=1) return self._sparsify_shard(df) + def _sparsify_shard(self, df: pd.DataFrame) -> tuple[sp.csc_matrix, np.ndarray]: + """Make X and y as scipy sparse arrays for XGBoost. + + Args: + - df (pandas.DataFrame): Data frame to sparsify. + + Returns: + - tuple[scipy.sparse.csr_matrix, numpy.ndarray]: Tuple of feature data and labels. + """ + labels = df.loc[:, [col for col in df.columns if col.endswith("/task")]] + data = df.drop(columns=labels.columns) + for col in data.columns: + if not isinstance(data[col].dtype, pd.SparseDtype): + data[col] = pd.arrays.SparseArray(data[col]) + sparse_matrix = data.sparse.to_coo() + return csr_matrix(sparse_matrix), labels.values + + def _filter_shard_files_on_window_and_aggs(self, file: Path) -> bool: + parts = file.relative_to(self.dynamic_data_path).parts + if not parts: + return False + + windows_part = parts[0] + aggs_part = "/".join(parts[1:-1]) + + return (self.window_set is None or windows_part in self.window_set) and ( + self.aggs_set is None or aggs_part in self.aggs_set + ) + + def _filter_shard_on_codes_and_freqs(self, df: pd.DataFrame) -> pd.DataFrame: + """Filter the dynamic data frame based on the inclusion sets. + + Args: + - df (pd.DataFrame): Data frame to filter. + + Returns: + - pd.DataFrame: Filtered data frame. + """ + code_parts = ["/".join(col.split("/")[1:-2]) for col in df.columns] + frequency_parts = ["/".join(col.split("/")[1:-1]) for col in df.columns] + + filtered_columns = [ + col + for col, code_part, freq_part in zip(df.columns, code_parts, frequency_parts) + if (self.codes_set is None or code_part in self.codes_set) + and (self.min_frequency_set is None or freq_part in self.min_frequency_set) + ] + filtered_columns.extend([col for col in df.columns if col.endswith("/task")]) + + return df[filtered_columns] + + def _assert_correct_sorting(self, shard: pd.DataFrame): + """Assert that the shard is sorted correctly.""" + if "timestamp" in shard.columns: + sort_columns = ["patient_id", "timestamp"] + else: + sort_columns = ["patient_id"] + assert shard[sort_columns].equals(shard[sort_columns].sort_values(by=sort_columns)), ( + "Shard is not sorted on correctly. " + "Please ensure that the data is sorted on patient_id and timestamp, if applicable." + ) + def next(self, input_data: Callable): """Advance the iterator by 1 step and pass the data to XGBoost. 
     def next(self, input_data: Callable):
         """Advance the iterator by 1 step and pass the data to XGBoost.

         This function is called by XGBoost during the construction of ``DMatrix``
@@ -236,7 +246,7 @@ def next(self, input_data: Callable):

         # input_data is a function passed in by XGBoost who has the exact same signature of
         # ``DMatrix``
-        X, y = self._load_dynamic_shard_by_index(self._it)  # self._data_shards[self._it])
+        X, y = self._get_shard_by_index(self._it)
         input_data(data=X, label=y)
         self._it += 1
         # Return 1 to let XGBoost know we haven't seen all the files yet.
@@ -255,7 +265,7 @@ def collect_in_memory(self) -> tuple[sp.csr_matrix, np.ndarray]:
         X = []
         y = []
         for i in range(len(self._data_shards)):
-            X_, y_ = self._load_dynamic_shard_by_index(i)
+            X_, y_ = self._get_shard_by_index(i)
             X.append(X_)
             y.append(y_)

@@ -288,6 +298,8 @@ def __init__(self, cfg: DictConfig):
     def train(self):
         """Train the model."""
         self._build()
+        # TODO: add in eval, early stopping, etc.
+        # TODO: check for NaN and inf in labels!
         self.model = xgb.train(
             OmegaConf.to_container(self.cfg.model), self.dtrain
         )  # do we want eval and things?
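`next` and `collect_in_memory` implement XGBoost's external-memory protocol: during `DMatrix` construction XGBoost calls `next(input_data)` repeatedly, and the iterator feeds one shard per call, returning 1 while shards remain and 0 when exhausted. A minimal, self-contained sketch of that protocol, assuming the `xgboost.DataIter` API (xgboost >= 1.5) and random stand-in shards:

```python
# Sketch of XGBoost's DataIter external-memory protocol; shard contents are random stand-ins.
import numpy as np
import scipy.sparse as sp
import xgboost as xgb


class ShardIter(xgb.DataIter):
    def __init__(self, n_shards: int):
        self._n, self._it = n_shards, 0
        super().__init__(cache_prefix="./cache")  # where XGBoost writes its on-disk cache

    def next(self, input_data) -> int:
        if self._it == self._n:
            return 0  # no more shards: signal end of data
        rng = np.random.default_rng(self._it)
        X = sp.random(100, 20, density=0.1, format="csr", random_state=self._it)
        input_data(data=X, label=rng.integers(0, 2, 100))  # feed one shard
        self._it += 1
        return 1  # more shards remain

    def reset(self) -> None:
        self._it = 0  # XGBoost rewinds the iterator between passes


dtrain = xgb.DMatrix(ShardIter(n_shards=3))
booster = xgb.train({"objective": "binary:logistic"}, dtrain)
```

This is the same contract the diff's iterator satisfies with `_get_shard_by_index`; `collect_in_memory` is simply the eager alternative that stacks all shards into one matrix up front.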