Skip to content

Commit

Permalink
added implementation description with file path information
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed Oct 22, 2024
1 parent 53d136e commit d2ce1ad
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 91 deletions.
246 changes: 191 additions & 55 deletions docs/implementation.md
Original file line number Diff line number Diff line change
@@ -1,105 +1,241 @@
# The MEDS-Tab Architecture

In this section, we describe the MEDS-Tab architecture, specifically some of the pipeline choices we made to reduce memory usage and increase speed during the tabularization process and XGBoost tuning process.
MEDS-Tab is designed to address two key challenges in healthcare machine learning: (1) efficiently tabularizing large-scale electronic health record (EHR) data and (2) training competitive baseline models on this tabularized data. This document outlines the architecture and implementation details of MEDS-Tab's pipeline.

We break our method into 4 discrete parts:
## Overview

1. Describe codes (compute feature frequencies)
2. Tabularization of time-series data
3. Efficient data caching for task-specific rows
4. XGBoost training
The MEDS-Tab pipeline consists of six main stages, with the first (stage 0) being optional:

## 1. Describe Codes (compute feature frequencies)
0. Data Resharding (Optional)
1. Data Description (Code Frequency Analysis)
2. Static Data Tabularization
3. Time-Series Data Tabularization
4. Task-Specific Data Caching
5. Model Training

This initial stage processes a pre-shareded dataset. We expect a structure as follows where each shard contains a subset of the patients:
Each stage is designed with scalability and efficiency in mind, using sparse matrix operations and data sharding to handle large-scale medical datasets.

## Stage 0: Data Resharding (Optional)

This optional preliminary stage helps optimize data processing by restructuring the input data into manageable shards. Resharding is particularly useful when dealing with large datasets or when experiencing memory constraints. The process uses the MEDS_transform-reshard_to_split command and supports parallel processing via Hydra's joblib launcher, with configurable shard sizes based on number of subjects.

Consider resharding if you're experiencing memory issues in later stages, need to process very large datasets, want to enable efficient parallel processing, or have uneven distribution of data across existing shards.

### Output Structure
```text
/PATH/TO/MEDS/DATA
/PATH/TO/MEDS_RESHARD_DIR
└─── <SPLIT A>
│ │ <SHARD 0>.parquet
│ │ <SHARD 1>.parquet
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>.parquet
│ <SHARD 1>.parquet
│ ...
```

## Stage 1: Data Description

The first stage analyzes the MEDS data to compute code frequencies and categorize features. This information is crucial for subsequent feature selection and optimization. The implementation iterates through data shards to compute feature frequencies and categorizes codes into dynamic codes (codes with timestamps), dynamic numeric values (codes with timestamps and numerical values), static codes (codes without timestamps), and static numeric values (codes without timestamps but with numerical values). Results are stored in a `${output_dir}/metadata/codes.parquet` file for use in subsequent stages, where `output_dir` is a key word argument.

### Input Data Structure
```text
/PATH/TO/MEDS/DATA
└─── <SPLIT A>
│ │ <SHARD 0>.parquet
│ │ <SHARD 1>.parquet
| │ ...
|
...
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>.parquet
│ <SHARD 1>.parquet
│ ...
```

We then compute and store feature counts, crucial for determining which features are relevant for further analysis.
## Stage 2: Static Data Tabularization

**Detailed Workflow:**
This stage processes static patient data (data without timestamps) into a format suitable for modeling. The implementation uses a dense pivot operations which because static data is generally relatively small. Then this stage converts the data to a sparse matrix format for consistency with time-series data. At first there is a single row for each `subject_id` with their static data. This is are duplicated by the number of unique times the patient has data to align with time-series events, and processing over shards is performed serially due to the manageable size of static data.

- **Data Loading and Sharding**: We iterate through shards to compute feature frequencies for each shard.
- **Count Aggregation**: After computing feature counts across shards, we aggregate them to get a final count of each feature across the entire dataset training dataset, which allows us to filter out infrequent features in the tabularization stage or when tuning XGBoost.
### Input Data Structure
```text
/PATH/TO/MEDS/DATA
└─── <SPLIT A>
│ │ <SHARD 0>.parquet
│ │ <SHARD 1>.parquet
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>.parquet
│ <SHARD 1>.parquet
│ ...
```

## 2. Tabularization of Time-Series Data
### Output Data Structure
```text
${output_dir}/tabularize/
└─── <SPLIT A>
│ │ <SHARD 0>/none/static/present.npz
│ │ <SHARD 0>/none/static/first.npz
│ │ <SHARD 1>/none/static/present.npz
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>/none/static/present.npz
│ <SHARD 0>/none/static/first.npz
│ <SHARD 1>/none/static/present.npz
│ ...
```

### Overview
Note that `.../none/static/present.npz` represents the tabularized data for static features with the aggregation method `static/present`. The `.../none/static/first.npz` represents the tabularized data for static features with the aggregation method `static/first`.

The tabularization stage of our pipeline, exposed via the cli commands:
## Stage 3: Time-Series Data Tabularization

- `meds-tab-tabularize-static` for tabularizing static data
- and `meds-tab-tabularize-time-series` for tabularizing the time series data
This stage handles the computationally intensive task of converting temporal medical data into feature vectors. The process employs several key optimizations: sparse matrix operations utilizing scipy.sparse for memory-efficient storage of sparse non-zero elements, data sharding that processes data in patient-based shards and enables parallel processing, and efficient aggregation using Polars for fast rolling window computations.

Static data is relatively small in the medical datasets, so we use a dense pivot operation, convert it to a sparse matrix, and then duplicate rows such that the static data will match up with the time series data rows generated in the next step. Static data is currently processed serially.
The process flow begins by loading shard data into a Polars DataFrame, converting it to sparse matrix format where rows represent events and columns represent features. It then aggregates same-day events per patient, applies rolling window aggregations, and stores results in sparse coordinate format (.npz files).

The script for tabularizing time series data primarily transforms a raw, unstructured dataset into a structured, feature-rich dataset by utilizing a series of sophisticated data processing steps. This transformation (as depicted in the figure below) involves converting raw time series from a Polars dataframe into a sparse matrix format, aggregating events that occur at the same date for the same patient, and then applying rolling window aggregations to extract temporal features.
### Input Data Structure
```text
/PATH/TO/MEDS/DATA
└─── <SPLIT A>
│ │ <SHARD 0>.parquet
│ │ <SHARD 1>.parquet
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>.parquet
│ <SHARD 1>.parquet
│ ...
```

![Time Series Tabularization Method](../assets/pivot.png)
### Output Data Structure
```text
${output_dir}/tabularize/
└─── <SPLIT A>
│ │ <SHARD 0>/1d/code/count.npz
│ │ <SHARD 0>/1d/value/sum.npz
| | ...
| | <SHARD 0>/7d/code/count.npz
│ │ <SHARD 0>/7d/value/sum.npz
│ │ ...
| | <SHARD 1>/1d/code/count.npz
│ │ <SHARD 1>/1d/value/sum.npz
│ │ ...
└─── <SPLIT B>
│ ...
```

### High-Level Tabularization Algorithm
The output structure consists of a directory for each split, containing subdirectories for each shard. Each shard subdirectory contains subdirectories for each aggregation method and window size, with the final output files stored in sparse coordinate format (.npz). In this example we have shown the output for the `1d` and `7d` window sizes and `code/count` and `value/sum` aggregation methods.

1. **Data Loading and Categorization**:
## Stage 4: Task-Specific Data Caching

- The script iterates through shards of patients, and shards can be processed in parallel using hydras joblib to launch multiple processes.
This stage aligns tabularized data with specific prediction tasks, optimizing for efficient model training. The implementation accepts task labels following the MEDS label-schema and matches them with nearest prior feature vectors. It filters tabularized data to include only task-relevant events while maintaining sparse format for efficient storage. Labels must include subject_id, prediction_time, and boolean_value for binary classification.

2. **Sparse Matrix Conversion**:

- Data from the Polars dataframe is converted into a sparse matrix format, where each row represents a unique event (patient x timestamp), and each column corresponds to a MEDS code for the patient.
### Input Data Structure
```text
${output_dir}/tabularize/ # Output from Stage 2 and 3
${input_label_dir}/**/*.parquet # All parquet files in the `input_label_dir` are used as labels
```

3. **Rolling Window Aggregation**:

- For each aggregation method (sum, count, min, max, etc.), events that occur on the same date for the same patient are aggregated. This reduces the amount of data we have to perform rolling windows over.
- Then we aggregate features over the specified rolling windows sizes.
### Output Data Structure

4. **Output Storage**:
Labels are cached in:
```text
$output_label_cache_dir
└─── <SPLIT A>
│ │ <SHARD 0>.parquet
│ │ <SHARD 1>.parquet
│ │ ...
└─── <SPLIT B>
│ <SHARD 0>.parquet
│ <SHARD 1>.parquet
│ ...
```

- Sparse array is converted to Coordinate List format and stored as a `.npz` file on disk.
- The file paths look as follows
For each shard, the labels are stored in a parquet file with the same name as the shard. The labels are stored in the `output_label_cache_dir` directory which by default is relative to the key word argument `$output_dir`: `output_label_cache_dir = ${output_dir}/${task_name}/labels`.

Task specific tabularized data is cached in the following format:
```text
/PATH/TO/MEDS/TABULAR_DATA
$output_tabularized_cache_dir
└─── <SPLIT A>
├─── <SHARD 0>
│ ├───code
│ │ └───count.npz
│ └───value
│ └───sum.npz
...
│ │ <SHARD 0>/1d/code/count.npz
│ │ <SHARD 0>/1d/value/sum.npz
| | <SHARD 0>/none/static/present.npz
| | <SHARD 0>/none/static/first.npz
| | ...
| | <SHARD 0>/7d/code/count.npz
│ │ <SHARD 0>/7d/value/sum.npz
│ │ ...
| | <SHARD 1>/1d/code/count.npz
│ │ <SHARD 1>/1d/value/sum.npz
│ │ <SHARD 1>/none/static/present.npz
| | <SHARD 1>/none/static/first.npz
│ │ ...
└─── <SPLIT B>
│ ...
```
The output structure is identical to the structure in Stages 2 and 3, but where we filter rows in the sparse matrix to only include events relevant to the task. This is done by selecting one row for each label that corresponds with the nearest prior event. The task-specific tabularized data is stored in the `output_tabularized_cache_dir` directory. By default this directory is relative to the key word argument `$output_dir`: `output_tabularized_cache_dir = ${output_dir}/${task_name}/task_cache`.

## 3. Efficient Data Caching for Task-Specific Rows
## Stage 5: Model Training

Now that we have generated tabular features for all the events in our dataset, we can cache subsets relevant for each task we wish to train a supervised model on. This step is critical for efficiently training machine learning models on task-specific data without having to load the entire dataset.
The final stage provides efficient model training capabilities, particularly optimized for XGBoost. The system incorporates extended memory support through sequential shard loading during training and efficient data loading through custom iterators. AutoML integration uses Optuna for hyperparameter optimization, tuning across model parameters, aggregation methods, window sizes, and feature selection thresholds.

**Detailed Workflow:**
### Input Data Structure
```text
# Location of task, split, and shard specific tabularized data
${input_tabularized_cache_dir} # Output from Stage 4
# Location of task, split, and shard specific label data
${input_label_cache_dir} # Output from Stage 4
```

- **Row Selection Based on Tasks**: Only the data rows that are relevant to the specific tasks are selected and cached. This reduces the memory footprint and speeds up the training process.
- **Use of Sparse Matrices for Efficient Storage**: Sparse matrices are again employed here to store the selected data efficiently, ensuring that only non-zero data points are kept in memory, thus optimizing both storage and retrieval times.
### Output Data Structure

The file structure for the cached data mirrors that of the tabular data, also consisting of `.npz` files, where users must specify the directory that stores labels. Labels must follow the [MEDS label-schema](https://github.com/Medical-Event-Data-Standard/meds?tab=readme-ov-file#the-label-schema), specifically including the `subject_id`, `prediction_time`, and `boolean_value` columns which are necessary for binary classification tasks.
For single runs, the output structure is as follows:
```text
# Where to output the model and cached data
time_output_model_dir = ${output_model_dir}/${now:%Y-%m-%d_%H-%M-%S}
├── config.log
├── performance.log
└── xgboost.json # model weights
```

For `multirun` optuna hyperparameter sweeps we get the following output structure:
```text
# Where to output the model and cached data
time_output_model_dir = ${output_model_dir}/${now:%Y-%m-%d_%H-%M-%S}
├── best_trial
| ├── config.log
| ├── performance.log
| └── xgboost.json # model weights
├── hydra
| └── optimization_results.yaml # contains the optimal trial hyperparameters and performance
└── sweep_results # This folder contains raw results for every hyperparameter trial
└── <TRIAL_1_ID>
├── config.log # model config log
├── performance.log # model performance log
└── xgboost.json # model weights
└── <TRIAL_2_ID>
...
```

## 4. XGBoost Training
`output_model_dir` is a keyword argument that specifies the directory where the model and cached data are stored. By default, we append the current date and time to the directory name to avoid overwriting previous runs, and use the `time_output_model_dir` variable to store the full path. If you use a different `model_launcher` than XGBoost, the model weights file will be named accordingly for that model (and will be a `.pkl` file instead of a `json`).

The final stage uses the processed and cached data to train an XGBoost model. This stage is optimized to handle the sparse data structures produced in earlier stages efficiently.
### Supported Models and Processing Options
The default model is XGBoost, with additional options including KNN Classifier, Logistic Regression, Random Forest Classifier, SGD Classifier, and experimental AutoGluon support. Data processing options include sparse-preserving normalization (standard_scaler, max_abs_scaler) and imputation methods that convert to dense format (mean_imputer, median_imputer, mode_imputer). By default no normalization is applied and missing values are treated as missing by `xgboost` or as zero by other models.

**Detailed Workflow:**
## Additional Considerations

- **Iterator for Data Loading**: Custom iterators are designed to load sparse matrices efficiently into the XGBoost training process, which can handle sparse inputs natively, thus maintaining high computational efficiency.
- **Training and Validation**: The model is trained using the tabular data, with evaluation steps that include early stopping to prevent overfitting and tuning of hyperparameters based on validation performance.
- **Hyperaparameter Tuning**: We use [optuna](https://optuna.org/) to tune over XGBoost model pramters, aggregations, window sizes, and the minimimum code inclusion count.
The architecture emphasizes robust memory management through sparse matrices and efficient data sharding, while supporting parallel processing and handling of high-dimensional feature spaces. The system is optimized for performance, minimizing memory footprint and computational overhead while enabling processing of datasets with hundreds of millions of events and tens of thousands of unique medical codes.
Loading

0 comments on commit d2ce1ad

Please sign in to comment.