# How does MEDS-Tab Work?

#### What do you mean "tabular pipelines"? Isn't _all_ structured EHR data already tabular?

This is a common misconception. _Tabular_ data refers to data that can be organized in a consistent, logical
set of rows/columns such that the entirety of a "sample" or "instance" for modeling or analysis is contained
in a single row, and the set of columns possibly observed (there can be missingness) is consistent across all
rows. Structured EHR data does not satisfy this definition: we will have different numbers of observations
of medical codes and values at different timestamps for different patients, so it cannot simultaneously
satisfy the (1) "single row per instance", (2) "consistent set of columns", and (3) "logical" requirements.
Thus, in this pipeline, when we say we will produce a "tabular" view of MEDS data, we mean a dataset that can
realize these constraints, which will explicitly involve summarizing the patient data over various historical
or future windows in time to produce a single row per patient with a consistent, logical set of columns
(though there may still be missingness).

## The MEDS-Tab Architecture

In this section, we describe the MEDS-Tab architecture, specifically the pipeline choices we made to reduce memory usage and increase speed during the tabularization and XGBoost tuning processes.

We break our method into four discrete parts:

1. Describe codes (compute feature frequencies)
2. Tabularize the time series data
3. Cache task-specific rows of data for efficient loading
4. Train XGBoost models

### 1. Describe Codes (Compute Feature Frequencies)

This initial stage processes a pre-sharded dataset. We expect a structure as follows, where each shard contains a subset of the patients:

```
/PATH/TO/MEDS/DATA
│
└───<SPLIT A>
│   │   <SHARD 0>.parquet
│   │   <SHARD 1>.parquet
│   │   ...
│
└───<SPLIT B>
│   │   <SHARD 0>.parquet
│   │   <SHARD 1>.parquet
│   │   ...
│
...
```

We then compute and store feature frequencies, which are crucial for determining which features are relevant for further analysis.

**Detailed Workflow:**

- **Data Loading and Sharding**: We iterate through shards to compute feature frequencies for each shard.
- **Frequency Aggregation**: After computing frequencies across shards, we aggregate them to get a final count of each feature across the entire training dataset, which allows us to filter out infrequent features in the tabularization stage or when tuning XGBoost. A minimal sketch of this stage is shown below.

This stage outputs a parquet file containing the aggregated feature frequencies.

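As an illustration, the sketch below shows how per-shard frequencies could be computed with Polars and then aggregated. The helper names and the `code`/`numeric_value` column names are assumptions for the example, not the exact MEDS-Tab implementation.

```python
# Hypothetical sketch of the frequency stage; column and helper names are
# illustrative assumptions (requires polars).
from pathlib import Path

import polars as pl


def compute_shard_frequencies(shard_path: Path) -> pl.DataFrame:
    """Count how often each code occurs (and how often it carries a numeric value) in one shard."""
    df = pl.read_parquet(shard_path)
    return df.group_by("code").agg(
        pl.len().alias("code/count"),
        pl.col("numeric_value").is_not_null().sum().alias("value/count"),
    )


def aggregate_frequencies(shard_paths: list[Path]) -> pl.DataFrame:
    """Sum per-shard counts into dataset-wide feature frequencies."""
    per_shard = [compute_shard_frequencies(p) for p in shard_paths]
    return (
        pl.concat(per_shard)
        .group_by("code")
        .agg(pl.col("code/count").sum(), pl.col("value/count").sum())
    )


if __name__ == "__main__":
    shards = sorted(Path("/PATH/TO/MEDS/DATA/train").glob("*.parquet"))
    aggregate_frequencies(shards).write_parquet("code_frequencies.parquet")
```

Because each shard is processed independently, this stage parallelizes naturally across shards before the final aggregation.
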
### 2. Tabularization of Time Series Data

### Overview

The tabularization stage of our pipeline is exposed via two CLI commands:

- `meds-tab-tabularize-static` for tabularizing the static data
- `meds-tab-tabularize-time-series` for tabularizing the time series data

Static data is relatively small in medical datasets, so we use a dense pivot operation, convert the result to a sparse matrix, and then duplicate rows so that the static data lines up with the time series rows generated in the next step. Static data is currently processed serially.

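A minimal sketch of this static step is below. The long-format columns (`patient_id`, `code`, `numeric_value`) and the `pivot` call (polars >= 1.0 API) are assumptions for the example, not the exact MEDS-Tab code.

```python
# Illustrative static-data step: dense pivot, then conversion to a sparse matrix.
import polars as pl
from scipy.sparse import csr_matrix

# Hypothetical static observations, one code per row.
static_df = pl.DataFrame(
    {
        "patient_id": [1, 1, 2],
        "code": ["GENDER//F", "ETHNICITY//X", "GENDER//M"],
        "numeric_value": [1.0, 1.0, 1.0],
    }
)

# Dense pivot: one row per patient, one column per static code (polars >= 1.0 API).
wide = static_df.pivot(
    on="code", index="patient_id", values="numeric_value", aggregate_function="first"
).fill_null(0)

# The static table is small, so a dense pivot is cheap; convert it to sparse afterwards.
static_matrix = csr_matrix(wide.drop("patient_id").to_numpy())
print(static_matrix.shape)  # (n_patients, n_static_codes)
```
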
The script for tabularizing time series data transforms a raw, unstructured dataset into a structured, feature-rich dataset through a series of data processing steps. This transformation involves converting raw time series from a Polars dataframe into a sparse matrix format, aggregating events that occur on the same date for the same patient, and then applying rolling window aggregations to extract temporal features. Here's a step-by-step breakdown of the algorithm:

### High-Level Steps

1. **Data Loading and Categorization**:

   - The script iterates through shards of patients; shards can be processed in parallel using Hydra's joblib launcher to launch multiple processes.

2. **Sparse Matrix Conversion**:

   - Data from the Polars dataframe is converted into a sparse matrix format. This step is crucial for efficient memory management, especially when dealing with large datasets.

3. **Event Aggregation**:

   - Events that occur on the same date for the same patient are aggregated. This reduces redundancy in the data and significantly speeds up the rolling window aggregations on datasets with many concurrent observations.

4. **Rolling Window Aggregation**:

   - The aggregated data undergoes a rolling window operation in which various statistical methods (sum, count, min, max, etc.) are applied to extract features over specified window sizes (a sketch of steps 2-4 follows the file layout below).

5. **Output Storage**:

   - The sparse array is converted to coordinate list (COO) format and stored as a `.npz` file on disk.
   - The file paths look as follows:

```
/PATH/TO/MEDS/TABULAR_DATA
│
└───<SPLIT A>
    ├───<SHARD 0>
    │   ├───code
    │   │   └───count.npz
    │   └───value
    │       └───sum.npz
    ...
```

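The sketch below walks through steps 2-4 on a toy long-format events table. The column names, the 30-day window, and the COO output path are illustrative assumptions rather than the exact MEDS-Tab code (polars >= 1.0 and scipy assumed).

```python
# Illustrative sketch of daily aggregation, rolling-window features, and sparse output.
import numpy as np
import polars as pl
from scipy import sparse

# Hypothetical long-format events: one row per (patient, date, code) observation.
events = pl.DataFrame(
    {
        "patient_id": [1, 1, 1, 2],
        "date": pl.Series(["2020-01-01", "2020-01-01", "2020-01-15", "2020-01-02"]).str.to_date(),
        "code": ["LAB//A", "LAB//A", "LAB//B", "LAB//A"],
    }
)

# Step 3: aggregate events occurring on the same date for the same patient.
daily = events.group_by(["patient_id", "date", "code"]).agg(pl.len().alias("count"))

# Step 4: rolling window aggregation (e.g., 30-day code counts).
windowed = (
    daily.sort(["patient_id", "code", "date"])
    .rolling(index_column="date", period="30d", group_by=["patient_id", "code"])
    .agg(pl.col("count").sum().alias("code/count/30d"))
)

# Steps 2/5: store the result as a sparse matrix in COO format.
codes = {c: j for j, c in enumerate(sorted(windowed["code"].unique()))}
rows = np.arange(len(windowed))
cols = np.array([codes[c] for c in windowed["code"]])
vals = windowed["code/count/30d"].to_numpy()
matrix = sparse.coo_matrix((vals, (rows, cols)), shape=(len(windowed), len(codes)))
sparse.save_npz("code_count_30d.npz", matrix)
```
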
### 3. Efficient Data Caching for Task-Specific Rows

Now that we have generated tabular features for all the events in our dataset, we can cache the subsets relevant to each task we wish to train a supervised model on. This step is critical for efficiently training machine learning models on task-specific data without having to load the entire dataset.

**Detailed Workflow:**

- **Row Selection Based on Tasks**: Only the data rows that are relevant to the specific tasks are selected and cached. This reduces the memory footprint and speeds up the training process.
- **Use of Sparse Matrices for Efficient Storage**: Sparse matrices are again employed here to store the selected data efficiently, ensuring that only non-zero data points are kept in memory, thus optimizing both storage and retrieval times.

The file structure for the cached data mirrors that of the tabular data and also uses `.npz` files. Users must specify a directory of labels that follows the same shard file structure as the input MEDS data from step (1); label parquets need `patient_id`, `timestamp`, and `label` columns.

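A rough sketch of what this caching step could look like is below. The per-shard index file, the paths, and the join logic are hypothetical and only illustrate the idea of selecting and caching label-aligned rows.

```python
# Hypothetical task-caching sketch: select the matrix rows whose
# (patient_id, timestamp) appear in the label file, then cache them.
from pathlib import Path

import polars as pl
from scipy import sparse

# Hypothetical per-shard row index: one (patient_id, timestamp) per matrix row.
row_index = pl.read_parquet("tabularized/train/0/index.parquet").with_row_index("row_id")
labels = pl.read_parquet("labels/train/0.parquet")  # patient_id, timestamp, label

# Keep only the rows that have a label for this task.
matched = row_index.join(labels, on=["patient_id", "timestamp"], how="inner")

matrix = sparse.load_npz("tabularized/train/0/code/count.npz").tocsr()
task_matrix = matrix[matched["row_id"].to_numpy()]

# Cache the task-specific rows and their labels for fast loading during training.
out_dir = Path("task_cache/train/0/code")
out_dir.mkdir(parents=True, exist_ok=True)
sparse.save_npz(out_dir / "count.npz", task_matrix)
matched.select(["patient_id", "timestamp", "label"]).write_parquet(out_dir.parent / "labels.parquet")
```
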
### 4. XGBoost Training

The final stage uses the processed and cached data to train an XGBoost model. This stage is optimized to handle the sparse data structures produced in earlier stages efficiently.

**Detailed Workflow:**

- **Iterator for Data Loading**: Custom iterators are designed to load sparse matrices efficiently into the XGBoost training process, which can handle sparse inputs natively, thus maintaining high computational efficiency (a sketch of such an iterator follows this list).
- **Training and Validation**: The model is trained on the tabular data, with evaluation steps that include early stopping to prevent overfitting and tuning of hyperparameters based on validation performance.
- **Hyperparameter Tuning**: We use [Optuna](https://optuna.org/) to tune the XGBoost model parameters, aggregations, window sizes, and the minimum code inclusion frequency.
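
For illustration, a custom iterator of this kind could be built on `xgboost.DataIter`, which XGBoost supports natively. The shard layout and label cache below are assumptions, not the exact MEDS-Tab implementation.

```python
# Illustrative sparse-shard iterator for XGBoost (hypothetical paths and layout).
from pathlib import Path

import numpy as np
import xgboost as xgb
from scipy import sparse


class SparseShardIter(xgb.DataIter):
    """Yield one cached sparse shard (and its labels) per call."""

    def __init__(self, shard_dirs: list[Path]):
        self.shard_dirs = shard_dirs
        self.it = 0
        super().__init__()

    def next(self, input_data) -> bool:
        if self.it == len(self.shard_dirs):
            return False  # signal end of iteration
        shard = self.shard_dirs[self.it]
        X = sparse.load_npz(shard / "code" / "count.npz").tocsr()  # hypothetical layout
        y = np.load(shard / "labels.npy")  # hypothetical cached labels
        input_data(data=X, label=y)  # hand this shard to XGBoost
        self.it += 1
        return True

    def reset(self) -> None:
        self.it = 0


shards = sorted(Path("task_cache/train").iterdir())
dtrain = xgb.QuantileDMatrix(SparseShardIter(shards))
booster = xgb.train(
    {"objective": "binary:logistic", "tree_method": "hist"}, dtrain, num_boost_round=100
)
```

Streaming shards through an iterator into a `QuantileDMatrix` keeps peak memory bounded by one shard at a time while still letting the `hist` tree method use all of the data.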