
Clean #5

Merged: 122 commits merged into main on Jun 5, 2024

Conversation

Collaborator

@Oufattole Oufattole commented Jun 5, 2024

Merging Functional Pipeline with CLI to Main

Summary by CodeRabbit

  • New Features

    • Introduced configuration files for MEDS tabular AutoML processes, including settings for XGBoost models, tabularization, and code frequency metadata generation.
    • Added scripts for training XGBoost models, optimizing hyperparameters, and tabularizing static and time-series data.
    • Enhanced the README with detailed explanations of the repository's utilities and usage.
  • Documentation

    • Updated README with comprehensive details on tabular data, installation, usage, and future TODOs.
  • Chores

    • Updated Python version in CI workflow to 3.12 and adjusted testing commands.
    • Added new authors and updated project description and dependencies in pyproject.toml.
  • Refactor

    • Enhanced pre-commit configuration with autoflake arguments for better code quality.
    • Added numerous helper functions and utilities for data processing, feature generation, and file management.
  • Bug Fixes

    • Improved handling and processing of data frames and sparse matrices for better performance and accuracy.

Oufattole and others added 30 commits May 25, 2024 08:06
…ts is correct when producing static representations
- Create individual scripts for:
  1. Retrieving column names
  2. Generating static representations
  3. Generating dynamic representations
  4. Summarizing data over windows
- Add doctests for retrieving column names and generating static representations
- Confirm functionality of the above tests
…dexed file types:

- Code data: Files with a patient_id and timestamp index, containing columns for each code with binary presence indicators (1 and 0).
- Value data: Files similar in structure but containing the numerical values observed for each code.
…n handling

- Introduce VALID_AGGREGATIONS to define permissible aggregations.
- Implement  to generate dynamic column aliases based on window size and aggregation.
- Extend  for dynamic expression creation based on aggregation type and window size, handling both cumulative and windowed aggregations.
- Enhance  to apply specified aggregations over defined window sizes, ensuring comprehensive data summarization.
- Update  to handle multiple dataframes, aggregate data using specified window sizes and aggregations, and ensure inclusion of all specified feature columns, adding missing ones with default values.
- Add extensive doctests to ensure accuracy of the summarization functions, demonstrating usage with both code and value data types.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…arse matrices instead of the sparse dataframes for values and codes. Also started modifying the summarization script to work with the sparse dataframes.
…arse pandas array for each patient, now we just use sparse scipy matrices
…rging rows that occur at the same time based on the current aggregation strategy. For example, if the aggregation is sum, we sum all rows that share a date; if the aggregation is count, we count all rows that share a date.
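
Purely as an illustration of the merge rule described in this commit (the PR itself operates on sparse SciPy matrices rather than DataFrames, so this is not the actual implementation), a same-timestamp sum aggregation could be pictured with a small Polars group-by:

```python
import polars as pl

df = pl.DataFrame(
    {
        "patient_id": [1, 1, 1],
        "timestamp": ["2020-01-01", "2020-01-01", "2020-01-02"],
        "value": [2.0, 3.0, 5.0],
    }
)

# Rows sharing (patient_id, timestamp) collapse into one row per the chosen
# aggregation: a sum here; a count aggregation would use pl.len() instead.
merged = df.group_by(["patient_id", "timestamp"], maintain_order=True).agg(
    pl.col("value").sum()
)
print(merged)  # 2020-01-01 -> 5.0, 2020-01-02 -> 5.0
```
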
@Oufattole Oufattole requested a review from mmcdermott June 5, 2024 05:20
Contributor

coderabbitai bot commented Jun 5, 2024

Warning

Rate limit exceeded

@Oufattole has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 37 minutes and 59 seconds before requesting another review.

How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

Commits

Files that changed from the base of the PR and between 6240c8a and ec73910.

Walkthrough

The updates enhance the MEDS tabular AutoML system by upgrading Python to version 3.12, refining package installations, and improving test execution commands. Configuration files for various tasks like XGBoost model training, tabularization, and code description are added. Scripts and utility functions are introduced or updated to support these tasks, ensuring efficient data processing, feature generation, and model optimization.

Changes

Files/Directories and change summary:

  • .github/workflows/tests.yaml: Updated Python version to 3.12, modified package installation commands, and adjusted test execution commands.
  • .pre-commit-config.yaml: Added arguments --in-place and --remove-all-unused-imports to autoflake hook.
  • README.md: Clarified concepts related to tabular data, MEDS datasets, and repository utilities.
  • pyproject.toml: Added authors, updated project description, dependencies, scripts, and build system requirements.
  • src/MEDS_tabular_automl/configs/...: Added new configuration files for default settings, code description, XGBoost launch, and tabularization.
  • src/MEDS_tabular_automl/describe_codes.py: Introduced functions for computing feature frequencies and filtering data.
  • src/MEDS_tabular_automl/file_name.py: Added helper functions for managing file names and paths.
  • src/MEDS_tabular_automl/generate_static_features.py: Provided functions for generating static data representations.
  • src/MEDS_tabular_automl/generate_summarized_reps.py: Introduced functions for aggregating sparse matrices and summarizing data frames.
  • src/MEDS_tabular_automl/generate_ts_features.py: Added functions for generating time series features.
  • src/MEDS_tabular_automl/mapper.py: Provided utilities for parallel map operations with caching and locking.
  • src/MEDS_tabular_automl/scripts/...: Added scripts for caching tasks, describing codes, launching XGBoost, optimizing XGBoost, and tabularizing static and time-series data.
  • src/MEDS_tabular_automl/utils.py: Introduced core dataset processing functions and constants.

Sequence Diagram(s) (Beta)

sequenceDiagram
    participant User
    participant Config
    participant Script
    participant Model
    participant Data

    User->>Config: Define settings
    Config->>Script: Pass configurations
    Script->>Data: Load and process data
    Script->>Model: Train model with data
    Model->>User: Provide results

Poem

In the land of code, a rabbit hops,
Through MEDS and data, it never stops.
Python upgraded, scripts refined,
Features tabular, models aligned.
With every change, it leaps with glee,
In the meadow of AutoML, wild and free! 🐇✨



@Oufattole Oufattole merged commit dba36ce into main Jun 5, 2024
0 of 2 checks passed
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 16

Outside diff range and nitpick comments (5)
README.md (3)

2-2: Clarify the scope of "limited automatic tabular ML pipelines" to set accurate expectations for users.


32-44: The description of the repository's key pieces is clear and informative. However, consider adding more details about the specific ML models supported in the AutoML pipelines.


75-76: The explanation of the tabularization process is succinct. Consider linking to the data/tabularize.py script directly if it's hosted in the repository for easier access.

src/MEDS_tabular_automl/scripts/launch_xgboost.py (2)

24-47: Ensure proper documentation for all attributes in the class Iterator.

Consider adding detailed descriptions for each attribute in the class docstring to improve code maintainability and readability.


302-321: Ensure that XGBoostModel initializes all its attributes properly.

It's good practice to explicitly initialize all attributes in the constructor to avoid hidden bugs.

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 6240c8a and a19ad3e.

Files selected for processing (26)
  • .github/workflows/tests.yaml (1 hunks)
  • .pre-commit-config.yaml (1 hunks)
  • README.md (3 hunks)
  • pyproject.toml (1 hunks)
  • src/MEDS_tabular_automl/configs/default.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/describe_codes.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/launch_xgboost.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/tabularization.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/tabularization/default.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/task_specific_caching.yaml (1 hunks)
  • src/MEDS_tabular_automl/configs/tmp.yaml.yaml (1 hunks)
  • src/MEDS_tabular_automl/describe_codes.py (1 hunks)
  • src/MEDS_tabular_automl/file_name.py (1 hunks)
  • src/MEDS_tabular_automl/generate_static_features.py (1 hunks)
  • src/MEDS_tabular_automl/generate_summarized_reps.py (1 hunks)
  • src/MEDS_tabular_automl/generate_ts_features.py (1 hunks)
  • src/MEDS_tabular_automl/mapper.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/cache_task.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/describe_codes.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/launch_xgboost.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/sweep_xgboost.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/tabularize_static.py (1 hunks)
  • src/MEDS_tabular_automl/scripts/tabularize_time_series.py (1 hunks)
  • src/MEDS_tabular_automl/utils.py (1 hunks)
  • tests/test_integration.py (1 hunks)
  • tests/test_tabularize.py (1 hunks)
Files skipped from review due to trivial changes (5)
  • .pre-commit-config.yaml
  • src/MEDS_tabular_automl/configs/default.yaml
  • src/MEDS_tabular_automl/configs/tabularization.yaml
  • src/MEDS_tabular_automl/configs/tabularization/default.yaml
  • src/MEDS_tabular_automl/configs/task_specific_caching.yaml
Additional context used
Ruff
src/MEDS_tabular_automl/scripts/tabularize_time_series.py

8-8: Module level import not at top of file (E402)


9-9: Module level import not at top of file (E402)


10-10: Module level import not at top of file (E402)


12-12: Module level import not at top of file (E402)


13-13: Module level import not at top of file (E402)


14-14: Module level import not at top of file (E402)


15-15: Module level import not at top of file (E402)


17-17: Module level import not at top of file (E402)


18-18: Module level import not at top of file (E402)


19-19: Module level import not at top of file (E402)


20-20: Module level import not at top of file (E402)


21-21: Module level import not at top of file (E402)


22-29: Module level import not at top of file (E402)

src/MEDS_tabular_automl/scripts/tabularize_static.py

13-13: Module level import not at top of file (E402)


15-15: Module level import not at top of file (E402)


17-23: Module level import not at top of file (E402)


24-24: Module level import not at top of file (E402)


25-25: Module level import not at top of file (E402)


26-26: Module level import not at top of file (E402)


27-34: Module level import not at top of file (E402)

src/MEDS_tabular_automl/generate_summarized_reps.py

6-6: Module level import not at top of file (E402)


7-7: Module level import not at top of file (E402)


9-9: Module level import not at top of file (E402)


10-10: Module level import not at top of file (E402)

tests/test_integration.py

5-5: Module level import not at top of file (E402)


6-6: Module level import not at top of file (E402)


7-7: Module level import not at top of file (E402)


8-8: Module level import not at top of file (E402)


9-9: Module level import not at top of file (E402)


11-11: Module level import not at top of file (E402)


12-12: Module level import not at top of file (E402)


13-22: Module level import not at top of file (E402)


24-24: Module level import not at top of file (E402)


25-25: Module level import not at top of file (E402)


26-33: Module level import not at top of file (E402)

tests/test_tabularize.py

5-5: Module level import not at top of file (E402)


6-6: Module level import not at top of file (E402)


7-7: Module level import not at top of file (E402)


8-8: Module level import not at top of file (E402)


9-9: Module level import not at top of file (E402)


11-11: Module level import not at top of file (E402)


12-12: Module level import not at top of file (E402)


14-14: Module level import not at top of file (E402)


15-15: Module level import not at top of file (E402)


16-23: Module level import not at top of file (E402)


24-31: Module level import not at top of file (E402)

LanguageTool
README.md

[uncategorized] ~62-~62: Possible missing comma found. (AI_HYDRA_LEO_MISSING_COMMA)
Context: ...ls and caches them with the event_id indexes which align them with the nearest prior...


[grammar] ~83-~83: It appears that a hyphen is missing in the plural noun “to-dos”? (TO_DO_HYPHEN)
Context: ...bularize.yaml`. ## AutoML Pipelines # TODOs 1. Leverage the "event bound aggregati...


[uncategorized] ~99-~99: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... Configuration File - MEDS_cohort_dir: directory of MEDS format dataset that i...


[grammar] ~99-~99: Possible verb agreement error. Did you mean “formats”? (Some collective nouns can be treated as both singular and plural, so ‘format’ is not always incorrect.) (COLLECTIVE_NOUN_VERB_AGREEMENT_VBP)
Context: ... - MEDS_cohort_dir: directory of MEDS format dataset that is ingested. - `tabularize...


[uncategorized] ~100-~100: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...at is ingested. - tabularized_data_dir: output directory of tabularized data. -...


[uncategorized] ~101-~101: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d data. - min_code_inclusion_frequency: The base feature inclusion frequency th...


[uncategorized] ~106-~106: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ype inclusion cutoff. - window_sizes: Beyond writing out a raw, per-event fla...


[uncategorized] ~111-~111: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... the raw representation files. - codes: A list of codes to include in the flat ...


[uncategorized] ~113-~113: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d in the flat representation. - aggs: A list of aggregations to apply to the ...


[uncategorized] ~114-~114: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ter than 0. - n_patients_per_sub_shard: The number of subjects that should be i...


[uncategorized] ~117-~117: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... more memory efficient. - do_overwrite: If True, this function will overwrite...


[uncategorized] ~119-~119: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...in the target save directory. - seed: The seed to use for random number gener...

Markdownlint
README.md

6-6: Expected: h2; Actual: h4 (MD001, heading-increment)
Heading levels should only increment by one level at a time


45-45: Expected: h2; Actual: h3 (MD001, heading-increment)
Heading levels should only increment by one level at a time


27-27: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document


31-31: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document


83-83: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document


97-97: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document

Additional comments not posted (65)
src/MEDS_tabular_automl/configs/describe_codes.yaml (1)

1-14: Configuration settings in describe_codes.yaml appear consistent and well-structured.

.github/workflows/tests.yaml (3)

22-25: Update to Python 3.12 aligns with modern Python practices.


29-29: Inclusion of specific test dependencies enhances the testing strategy.


36-36: Ensure the hf_cohort/ directory is correctly ignored for valid reasons.

Verification successful

The hf_cohort/ directory is referenced in the README.md file, indicating it contains scripts for generating task labels and processing datasets. However, the directory itself does not exist in the repository. This suggests that the directory might be dynamically created or is expected to be present in a different environment.

  • README.md: References to hf_cohort/ scripts for generating labels and processing datasets.

Given this context, it seems reasonable to exclude the hf_cohort/ directory from tests, likely because it is not part of the repository or is environment-specific.

Scripts executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify the reason for ignoring `hf_cohort/` directory in tests.

# Test: Check for specific configurations or data in `hf_cohort/` that justify its exclusion.
rg --type yaml --type python "hf_cohort" --files-with-matches

Length of output: 92



Script:

#!/bin/bash
# Description: Verify the reason for ignoring `hf_cohort/` directory in tests.

# Test: Check for specific configurations or data in `hf_cohort/` that justify its exclusion.
rg --type yaml --type py "hf_cohort" --files-with-matches

Length of output: 57



Script:

#!/bin/bash
# Description: List the contents of the `hf_cohort/` directory to understand its purpose.

# List the contents of the `hf_cohort/` directory.
ls -R hf_cohort/

Length of output: 76



Script:

#!/bin/bash
# Description: Search for references to `hf_cohort` in the entire repository to understand its context.

# Search for references to `hf_cohort` in all files.
rg "hf_cohort"

Length of output: 336

src/MEDS_tabular_automl/file_name.py (1)

5-28: Enhancements in file management functions appear logically sound and well-implemented.

pyproject.toml (3)

6-7: Addition of new project authors is accurately reflected.

Also applies to: 9-9


17-17: Updated dependencies support new functionalities and requirements.


37-38: Updated project URLs provide necessary resources and support.

src/MEDS_tabular_automl/configs/launch_xgboost.yaml (1)

1-81: Comprehensive and well-structured configuration for XGBoost model launching.

src/MEDS_tabular_automl/configs/tmp.yaml.yaml (1)

1-89: Temporary configuration settings are comprehensive and align with project standards.

src/MEDS_tabular_automl/scripts/sweep_xgboost.py (2)

20-32: The function generate_permutations is well-implemented and documented. It correctly generates all possible combinations of a list, which is essential for parameter tuning in XGBoost.
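
For readers outside the project, a minimal standalone sketch of what such a subset generator can look like with itertools; the name and exact behavior of the PR's generate_permutations may differ.

```python
from itertools import combinations


def generate_subsets(options: list) -> list[tuple]:
    """Return every non-empty combination of the input options (illustrative only)."""
    return [
        combo
        for size in range(1, len(options) + 1)
        for combo in combinations(options, size)
    ]


# e.g. generate_subsets(["30d", "365d", "full"]) yields all 7 non-empty window-size subsets
```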


51-82: The main function is well-structured and effectively sets up and executes the Optuna study for XGBoost optimization.

src/MEDS_tabular_automl/scripts/cache_task.py (2)

40-50: The function generate_row_cached_matrix correctly checks matrix dimensions and processes the data, ensuring data integrity before transformation.


53-97: The main function in cache_task.py is well-organized and effectively handles the caching of time-series data, ensuring efficient data processing and storage.

src/MEDS_tabular_automl/scripts/describe_codes.py (3)

58-60: The compute_fn function effectively computes feature frequencies, which is crucial for filtering out infrequent events in the dataset.


61-62: The write_fn function correctly handles the writing of dataframes to the output path, ensuring data persistence.


33-112: The main function in describe_codes.py is well-structured and effectively manages the computation and storage of feature frequencies, facilitating efficient data analysis.

src/MEDS_tabular_automl/scripts/tabularize_time_series.py (3)

88-89: The read_fn function effectively filters the input data based on resolved codes, ensuring that only relevant data is processed.


91-106: The compute_fn function in tabularize_time_series.py correctly processes and summarizes time-series data, ensuring accurate data aggregation.


36-123: The main function is well-organized and effectively orchestrates the data processing pipeline for summarizing time-series data, ensuring efficient and structured data storage.

src/MEDS_tabular_automl/generate_ts_features.py (3)

20-23: The function feature_name_to_code correctly converts feature names to codes, which is essential for processing time series data.


25-38: The get_long_code_df function effectively pivots the codes data frame to a long format one-hot representation, which is crucial for time series analysis.


119-160: The function get_flat_ts_rep in generate_ts_features.py correctly produces a flat time series representation from a given data frame, ensuring accurate feature extraction and summarization.

src/MEDS_tabular_automl/scripts/tabularize_static.py (3)

136-137: The read_fn function effectively filters the input data based on resolved codes, ensuring that only relevant data is processed.


139-144: The compute_fn function in tabularize_static.py correctly processes and summarizes static data, ensuring accurate data aggregation.


41-160: The main function is well-organized and effectively orchestrates the data processing pipeline for tabularizing static data, ensuring efficient and structured data storage.

README.md (4)

Line range hint 7-18: Excellent explanation of "tabular data" in the context of MEDS datasets. This clarification helps in setting the right context for new users.

Tools
Markdownlint

6-6: Expected: h2; Actual: h4 (MD001, heading-increment)
Heading levels should only increment by one level at a time


99-119: The YAML configuration details are well-documented. Ensure that all default values and options are clearly defined to avoid any confusion during setup.

Tools
LanguageTool

[uncategorized] ~99-~99: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... Configuration File - MEDS_cohort_dir: directory of MEDS format dataset that i...


[grammar] ~99-~99: Possible verb agreement error. Did you mean “formats”? (Some collective nouns can be treated as both singular and plural, so ‘format’ is not always incorrect.) (COLLECTIVE_NOUN_VERB_AGREEMENT_VBP)
Context: ... - MEDS_cohort_dir: directory of MEDS format dataset that is ingested. - `tabularize...


[uncategorized] ~100-~100: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...at is ingested. - tabularized_data_dir: output directory of tabularized data. -...


[uncategorized] ~101-~101: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d data. - min_code_inclusion_frequency: The base feature inclusion frequency th...


[uncategorized] ~106-~106: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ype inclusion cutoff. - window_sizes: Beyond writing out a raw, per-event fla...


[uncategorized] ~111-~111: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... the raw representation files. - codes: A list of codes to include in the flat ...


[uncategorized] ~113-~113: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d in the flat representation. - aggs: A list of aggregations to apply to the ...


[uncategorized] ~114-~114: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ter than 0. - n_patients_per_sub_shard: The number of subjects that should be i...


[uncategorized] ~117-~117: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... more memory efficient. - do_overwrite: If True, this function will overwrite...


[uncategorized] ~119-~119: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...in the target save directory. - seed: The seed to use for random number gener...


47-64: Ensure the scripts mentioned are correctly linked and accessible in the repository, especially for new users who might follow these as a starting point.

Tools
LanguageTool

[uncategorized] ~62-~62: Possible missing comma found. (AI_HYDRA_LEO_MISSING_COMMA)
Context: ...ls and caches them with the event_id indexes which align them with the nearest prior...


21-25: Investigate the scalability of TemporAI for handling MEDS datasets as mentioned. This could potentially streamline the process if feasible.

src/MEDS_tabular_automl/generate_static_features.py (3)

27-42: The function convert_to_matrix is well-implemented with clear documentation. Ensure that the data types of the matrix elements are consistent with the expected types in downstream processing.


85-154: The function summarize_static_measurements provides a robust mechanism for aggregating static measurements. Ensure that all edge cases are handled, especially with data types and missing values.


157-182: The function get_flat_static_rep is crucial for generating a flat representation of static data. Verify that the output is correctly integrated into the overall data pipeline and that the data types are consistent.

src/MEDS_tabular_automl/describe_codes.py (2)

13-64: The function compute_feature_frequencies is well-documented and effectively computes feature frequencies. Ensure that the function is tested thoroughly, especially with edge cases involving unusual data distributions.


67-92: The function convert_to_freq_dict is a key component in converting dataframes to frequency dictionaries. Verify that it handles all possible data types and structures expected in the input dataframes.

src/MEDS_tabular_automl/generate_summarized_reps.py (4)

13-26: The function sparse_aggregate is well-implemented with support for multiple aggregation methods. Ensure that the function is tested with various types of sparse matrices to verify its robustness.
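
For reviewers writing those tests, a hedged, standalone sketch of dispatching a named aggregation over a SciPy sparse matrix; the supported method names and return conventions here are assumptions, not the PR's actual sparse_aggregate.

```python
import numpy as np
from scipy.sparse import csr_matrix


def aggregate_columns(matrix: csr_matrix, agg: str) -> np.ndarray:
    """Column-wise aggregation over a sparse matrix (illustrative only)."""
    if agg == "sum":
        return np.asarray(matrix.sum(axis=0)).ravel()
    if agg == "min":
        return matrix.min(axis=0).toarray().ravel()
    if agg == "max":
        return matrix.max(axis=0).toarray().ravel()
    if agg == "count":
        return matrix.getnnz(axis=0)  # stored (nonzero) entries per column
    raise ValueError(f"Unsupported aggregation: {agg}")


m = csr_matrix(np.array([[1.0, 0.0], [2.0, 3.0]]))
assert aggregate_columns(m, "sum").tolist() == [3.0, 3.0]
assert aggregate_columns(m, "count").tolist() == [2, 1]
```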


80-127: The function compute_agg applies aggregations to dataframes based on specified window sizes and aggregation methods. Ensure that the function handles all edge cases, especially with non-standard window sizes or aggregation methods.


178-245: The function generate_summary is crucial for generating summarized data frames. Verify that all feature columns specified are included in the final output and that default values are correctly applied where necessary.


44-77: The function aggregate_matrix is key to aggregating matrices based on window indices. Verify that the aggregation methods are correctly applied and that the output matrix has the expected dimensions and data types.

tests/test_integration.py (6)

36-43: Function run_command is well-implemented and handles subprocess execution and error reporting effectively.


47-87: The environment setup in the integration test is correctly implemented, ensuring proper directory and file management.


89-102: Script execution and output verification for describe_codes are correctly implemented, with appropriate assertions to check the output files and data integrity.


106-153: Execution and verification of the static data tabularization script are well-handled, with thorough checks for output file existence and data integrity.


161-208: The time series data tabularization script execution and its output verification are correctly implemented, ensuring data integrity and correct file handling.


210-250: Setup and execution of the task-specific caching script are correctly implemented, aligning with the overall testing strategy.

src/MEDS_tabular_automl/mapper.py (3)

14-53: Function get_earliest_lock is correctly implemented, effectively retrieving the earliest lock from a specified directory. The provided example is clear and demonstrates the function's usage well.
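
For context, a rough sketch of the "earliest lock" idea, under the assumption that each lock is a JSON file in a lock directory holding an ISO-format start time; the real file layout in mapper.py may differ.

```python
import json
from datetime import datetime
from pathlib import Path


def get_earliest_lock_sketch(lock_dir: Path) -> datetime | None:
    """Return the earliest lock start time in lock_dir, or None if no locks exist."""
    starts = [
        datetime.fromisoformat(json.loads(fp.read_text())["start"])
        for fp in lock_dir.glob("*.json")
    ]
    return min(starts) if starts else None
```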


56-82: Function register_lock is well-implemented, correctly registering a lock file in the specified directory. The example provided effectively demonstrates the function's correct behavior.


85-278: Function wrap is robustly implemented, providing a comprehensive solution for wrapping file-in file-out transformations with caching and locking. The examples are detailed and effectively illustrate the function's capabilities and edge cases.

src/MEDS_tabular_automl/scripts/launch_xgboost.py (1)

140-157: Ensure that _get_code_set handles edge cases and errors gracefully.

Consider adding error handling for cases where configuration settings might lead to unexpected results.

src/MEDS_tabular_automl/utils.py (16)

20-20: Update to pl.LazyFrame for DF_T aligns with the need for lazy evaluation in data processing.


21-21: Setting WRITE_USE_PYARROW to True ensures efficient data serialization using PyArrow.


22-22: Introduction of ROW_IDX_NAME as a constant improves code readability and maintainability.


70-73: The conversion to a sparse matrix using coo_array is efficient for handling large sparse datasets.
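
As a standalone refresher on the API used here (unrelated to the PR's exact call site), a coo_array is built from parallel data/row/col arrays:

```python
import numpy as np
from scipy.sparse import coo_array

# Three nonzero entries of a 2x3 matrix as (data, (row, col)) triplets.
data = np.array([1.0, 2.0, 3.0])
row = np.array([0, 1, 1])
col = np.array([0, 0, 2])

matrix = coo_array((data, (row, col)), shape=(2, 3))
print(matrix.toarray())
# [[1. 0. 0.]
#  [2. 0. 3.]]
```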


76-77: Using np.result_type to determine the minimum data type is a good practice for optimizing memory usage.


80-93: Optimization of data types in sparse_matrix_to_array is crucial for performance, especially in large datasets.


129-137: The parse_static_feature_column function is used effectively here to determine data types based on aggregation. This is a good example of code reuse.


186-216: The get_static_feature_cols function efficiently generates feature columns from static data. The use of list comprehensions and set operations is commendable for performance.


219-248: The get_ts_feature_cols function mirrors the static version but for time-series data. The consistent approach enhances maintainability.


274-300: Combining static and time-series feature columns in get_flat_rep_feature_cols is a key functionality for feature engineering in time-series analysis.


303-336: The load_meds_data function is well-documented and handles different data splits efficiently. The use of globbing to find files is a practical approach.


349-358: The get_unique_time_events_df function ensures data integrity by maintaining unique timestamps. The use of assertions is a good practice for debugging.


362-371: The get_feature_names function is a utility that aids in selecting specific feature columns based on aggregation types. This function enhances modularity by separating concerns.


374-378: The get_feature_indices function leverages the get_feature_names function to map features to their indices, which is crucial for operations that require direct index access.


423-444: The get_shard_prefix function is a utility that extracts relative file paths. This is useful for managing file paths in a consistent manner across different modules.


339-346: The get_events_df function effectively filters and prepares event data. Ensure that the feature columns are correctly specified to avoid runtime issues.

Verification successful

The feature columns passed to the get_events_df function are correctly specified in the provided instances.

  • tests/test_tabularize.py
  • tests/test_integration.py
  • src/MEDS_tabular_automl/generate_ts_features.py
  • src/MEDS_tabular_automl/generate_static_features.py
Scripts executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify that the feature columns passed to `get_events_df` are correctly specified and exist in the data.

# Test: Search for the function usage and check feature column specifications. Expect: Correctly specified feature columns.
rg --type python $'get_events_df'

Length of output: 63



Script:

#!/bin/bash
# Description: Verify that the feature columns passed to `get_events_df` are correctly specified and exist in the data.

# Test: Search for the function usage and check feature column specifications. Expect: Correctly specified feature columns.
rg --type py $'get_events_df'

Length of output: 1352



Script:

#!/bin/bash
# Description: Extract lines around the usage of `get_events_df` to verify feature column specifications.

# Extract lines around the usage of `get_events_df` in relevant files.
rg --type py -A 3 -B 3 $'get_events_df' tests/test_tabularize.py tests/test_integration.py src/MEDS_tabular_automl/generate_ts_features.py src/MEDS_tabular_automl/generate_static_features.py

Length of output: 6560

Comment on lines +73 to +91
def _get_code_masks(self, feature_columns: list, codes_set: set) -> Mapping[str, list[bool]]:
"""Create boolean masks for filtering features.

Creates a dictionary of boolean masks for each aggregation type. The masks are used to filter
the feature columns based on the specified included codes and minimum code inclusion frequency.

Args:
feature_columns: List of feature columns.
codes_set: Set of codes to include.

Returns:
Dictionary of code masks for each aggregation.
"""
code_masks = {}
for agg in set(self.cfg.tabularization.aggs):
feature_ids = get_feature_indices(agg, feature_columns)
code_mask = [True if idx in codes_set else False for idx in feature_ids]
code_masks[agg] = code_mask
return code_masks

Optimize the list comprehension in _get_code_masks.

Consider using numpy arrays or other efficient data structures for large-scale operations to improve performance.
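
A self-contained sketch of the NumPy variant being suggested, assuming the feature indices are plain integers as in the snippet above:

```python
import numpy as np


def build_code_mask(feature_ids: list[int], codes_set: set[int]) -> np.ndarray:
    """Vectorized replacement for the boolean list comprehension above."""
    return np.isin(np.asarray(feature_ids), list(codes_set))


# build_code_mask([3, 7, 11], {7, 11}) -> array([False, True, True])
```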

Comment on lines +94 to +108
def _load_matrix(self, path: Path) -> sp.csc_matrix:
"""Load a sparse matrix from disk.

Args:
- path (Path): Path to the sparse matrix.

Returns:
- sp.csc_matrix: Sparse matrix.
"""
npzfile = np.load(path)
array, shape = npzfile["array"], npzfile["shape"]
if array.shape[0] != 3:
raise ValueError(f"Expected array to have 3 rows, but got {array.shape[0]} rows")
data, row, col = array
return sp.csc_matrix((data, (row, col)), shape=shape)

Validate the loaded matrix shape more robustly in _load_matrix.

Consider adding more comprehensive checks for the matrix dimensions to ensure data integrity.
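
A hedged sketch of the kind of additional checks being asked for, assuming the same npz layout as the snippet above (a stacked 3xN "array" of data/row/col plus a "shape" entry):

```python
from pathlib import Path

import numpy as np
import scipy.sparse as sp


def load_matrix_checked(path: Path) -> sp.csc_matrix:
    """Load a (data, row, col) npz file with extra integrity checks (illustrative)."""
    npzfile = np.load(path)
    if "array" not in npzfile or "shape" not in npzfile:
        raise ValueError(f"{path} is missing the 'array' or 'shape' entry")
    array, shape = npzfile["array"], tuple(npzfile["shape"])
    if array.ndim != 2 or array.shape[0] != 3:
        raise ValueError(f"Expected a 3xN array, got shape {array.shape}")
    if len(shape) != 2:
        raise ValueError(f"Expected a 2D target shape, got {shape}")
    data, row, col = array
    if row.size and (row.max() >= shape[0] or col.max() >= shape[1]):
        raise ValueError("Row/column indices exceed the declared matrix shape")
    return sp.csc_matrix((data, (row.astype(int), col.astype(int))), shape=shape)
```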

Comment on lines +111 to +137
def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]:
"""Loads valid event ids and labels for each shard.

Returns:
- Tuple[Mapping[int, list], Mapping[int, list]]: Tuple containing:
dictionary from shard number to list of valid event ids -- used for indexing rows
in the sparse matrix
dictionary from shard number to list of labels for these valid event ids
"""
label_fps = {
shard: (Path(self.cfg.input_label_dir) / self.split / shard).with_suffix(".parquet")
for shard in self._data_shards
for shard in self._data_shards
}
cached_labels, cached_event_ids = dict(), dict()
for shard, label_fp in label_fps.items():
label_df = pl.scan_parquet(label_fp)
cached_event_ids[shard] = label_df.select(pl.col("event_id")).collect().to_series()

# TODO: check this for Nan or any other case we need to worry about
cached_labels[shard] = label_df.select(pl.col("label")).collect().to_series()
if self.cfg.model_params.iterator.binarize_task:
cached_labels[shard] = cached_labels[shard].map_elements(
lambda x: 1 if x > 0 else 0, return_dtype=pl.Int8
)

return cached_event_ids, cached_labels

Address the TODO in load_labels and ensure robust error handling.

Would you like me to help implement the error handling or open a GitHub issue to track this task?
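
One possible shape for that check, sketched in Polars; exactly how the authors want to treat null or NaN labels is still open, so the behavior below is an assumption.

```python
import polars as pl


def extract_labels(label_df: pl.LazyFrame, shard: str) -> pl.Series:
    """Collect the label column for a shard and fail loudly on missing values."""
    labels = label_df.select(pl.col("label")).collect().to_series()
    if labels.null_count() > 0:
        raise ValueError(f"Shard {shard} has {labels.null_count()} null labels")
    if labels.dtype in (pl.Float32, pl.Float64) and labels.is_nan().any():
        raise ValueError(f"Shard {shard} has NaN labels")
    return labels
```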

Comment on lines +387 to +421
@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
"""Optimize the model based on the provided configuration.

Args:
- cfg (DictConfig): Configuration dictionary.

Returns:
- float: Evaluation result.
"""
if not cfg.loguru_init:
hydra_loguru_init()

model = XGBoostModel(cfg)
model.train()

print(
"Time Profiling for window sizes ",
f"{cfg.tabularization.window_sizes} and min ",
"code frequency of {cfg.tabularization.min_code_inclusion_frequency}:",
)
print("Train Time: \n", model._profile_durations())
print("Train Iterator Time: \n", model.itrain._profile_durations())
print("Tuning Iterator Time: \n", model.ituning._profile_durations())
print("Held Out Iterator Time: \n", model.iheld_out._profile_durations())

# save model
save_dir = Path(cfg.output_dir)
save_dir.mkdir(parents=True, exist_ok=True)

logger.info(f"Saving the model to directory: {save_dir}")
model.model.save_model(save_dir / "model.json")
auc = model.evaluate()
logger.info(f"AUC: {auc}")
return auc

Optimize the main function in launch_xgboost.py.

Consider refactoring the main function to separate configuration loading, model training, and evaluation into distinct functions or methods for better modularity and testability.
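
One possible decomposition along those lines; the helper names below are invented for illustration, and XGBoostModel and logger refer to the objects already used in the snippet above (the hydra decorator is omitted for brevity).

```python
from pathlib import Path

# Hypothetical split of main() into smaller units; XGBoostModel and logger
# are the objects already defined/imported in launch_xgboost.py.


def train_model(cfg):
    model = XGBoostModel(cfg)
    model.train()
    return model


def report_profiling(model, cfg):
    logger.info(
        f"Window sizes {cfg.tabularization.window_sizes}, min code frequency "
        f"{cfg.tabularization.min_code_inclusion_frequency}"
    )
    logger.info(f"Train time:\n{model._profile_durations()}")


def save_and_evaluate(model, cfg) -> float:
    save_dir = Path(cfg.output_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    model.model.save_model(save_dir / "model.json")
    return model.evaluate()


def main(cfg) -> float:
    model = train_model(cfg)
    report_profiling(model, cfg)
    return save_and_evaluate(model, cfg)
```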

Comment on lines +201 to +334
len(list_subdir_files(Path(cfg.input_dir).parent, "parquet")) == 4
), "MEDS train split Data Files Should be 4!"
for f in meds_files:
assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!"
split_json = json.load(StringIO(SPLITS_JSON))
splits_fp = MEDS_cohort_dir / "splits.json"
json.dump(split_json, splits_fp.open("w"))
# Step 1: Describe Codes - compute code frequencies
describe_codes.main(cfg)

assert (Path(cfg.output_dir) / "config.yaml").is_file()
assert Path(cfg.output_filepath).is_file()

feature_columns = get_feature_columns(cfg.output_filepath)
assert get_feature_names("code/count", feature_columns) == sorted(CODE_COLS)
assert get_feature_names("static/present", feature_columns) == sorted(STATIC_PRESENT_COLS)
assert get_feature_names("static/first", feature_columns) == sorted(STATIC_FIRST_COLS)
for value_agg in VALUE_AGGREGATIONS:
assert get_feature_names(value_agg, feature_columns) == sorted(VALUE_COLS)

# Step 2: Tabularization
tabularize_static_config = {
"MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
"do_overwrite": False,
"seed": 1,
"hydra.verbose": True,
"tqdm": False,
"loguru_init": True,
"tabularization.min_code_inclusion_frequency": 1,
"tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
"tabularization.window_sizes": "[30d,365d,full]",
}

with initialize(
version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
): # path to config.yaml
overrides = [f"{k}={v}" for k, v in tabularize_static_config.items()]
cfg = compose(config_name="tabularization", overrides=overrides) # config.yaml
tabularize_static.main(cfg)
output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz"))
actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files]
assert set(actual_files) == set(EXPECTED_STATIC_FILES)
# Check the files are not empty
for f in output_files:
static_matrix = load_matrix(f)
assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!"
expected_num_cols = len(get_feature_names(f"static/{f.stem}", feature_columns))
assert static_matrix.shape[1] == expected_num_cols, (
f"Static Data Tabular Dataframe Should have {expected_num_cols}"
f"Columns but has {static_matrix.shape[1]}!"
)
split = f.parts[-5]
shard_num = f.parts[-4]
med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
expected_num_rows = (
get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
.collect()
.shape[0]
)
assert static_matrix.shape[0] == expected_num_rows, (
f"Static Data matrix Should have {expected_num_rows}"
f" rows but has {static_matrix.shape[0]}!"
)
allowed_codes = cfg.tabularization._resolved_codes
num_allowed_codes = len(allowed_codes)
feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
assert num_allowed_codes == len(
feature_columns
), f"Should have {len(feature_columns)} codes but has {num_allowed_codes}"

tabularize_time_series.main(cfg)

# confirm summary files exist:
output_files = list_subdir_files(cfg.output_dir, "npz")
actual_files = [
get_shard_prefix(Path(cfg.output_dir), each) + ".npz"
for each in output_files
if "none/static" not in str(each)
]
assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES)
for f in output_files:
ts_matrix = load_matrix(f)
assert ts_matrix.shape[0] > 0, "Time-Series Tabular Dataframe Should not be Empty!"
expected_num_cols = len(get_feature_names(f"{f.parent.stem}/{f.stem}", feature_columns))
assert ts_matrix.shape[1] == expected_num_cols, (
f"Time-Series Tabular Dataframe Should have {expected_num_cols}"
f"Columns but has {ts_matrix.shape[1]}!"
)
split = f.parts[-5]
shard_num = f.parts[-4]
med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
expected_num_rows = (
get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
.collect()
.shape[0]
)
assert ts_matrix.shape[0] == expected_num_rows, (
f"Time-Series Data matrix Should have {expected_num_rows}"
f" rows but has {ts_matrix.shape[0]}!"
)

Refactor test_tabularize to improve readability and maintainability.

Consider breaking down the test_tabularize function into smaller, more focused test functions. This can improve readability and make it easier to identify and address specific failures.
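
A hedged illustration of what that split could look like with a shared pytest fixture; all names here are hypothetical and the real setup involves more configuration than shown.

```python
import pytest


@pytest.fixture(scope="module")
def tabularized_outputs(tmp_path_factory):
    """Run the expensive pipeline once and share its output dir across tests."""
    out_dir = tmp_path_factory.mktemp("tabularize")
    # ... invoke describe_codes.main / tabularize_static.main / etc. here ...
    return out_dir


def test_static_outputs_exist(tabularized_outputs):
    assert list(tabularized_outputs.rglob("*.npz")), "expected static .npz outputs"


def test_static_matrix_shapes(tabularized_outputs):
    # ... load each matrix and compare row/column counts to the expected values ...
    ...
```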

Comment on lines +107 to +123
def write_df(df: coo_array, fp: Path, **kwargs):
"""Write shard to disk."""
do_overwrite = kwargs.get("do_overwrite", False)

if not do_overwrite and fp.is_file():
raise FileExistsError(f"{fp} exists and do_overwrite is {do_overwrite}!")

fp.parent.mkdir(exist_ok=True, parents=True)

if isinstance(df, pl.LazyFrame):
df.collect().write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW)
elif isinstance(df, pl.DataFrame):
df.write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW)
elif isinstance(df, coo_array):
store_matrix(df, fp)
else:
raise TypeError(f"Unsupported type for df: {type(df)}")

The write_df function's flexibility to handle different data types and overwrite behavior is well-implemented. However, ensure that the error messages are informative and guide the user on how to resolve potential issues.

- raise FileExistsError(f"{fp} exists and do_overwrite is {do_overwrite}!")
+ raise FileExistsError(f"File {fp} exists. Set 'do_overwrite' to True to overwrite it.")

Comment on lines +251 to +271
def get_prediction_ts_cols(
aggregations: list[str], ts_feature_cols: DF_T, window_sizes: list[str] | None = None
) -> list[str]:
"""Generates a list of feature column names that will be used for downstream task
Examples:
>>> feature_cols = ['A/code', 'A/value', 'C/code', 'C/value']
>>> window_sizes = None
>>> aggs = ['value/sum', 'code/count']
>>> get_prediction_ts_cols(aggs, feature_cols, window_sizes)
error
>>> window_sizes = ["1d"]
>>> get_prediction_ts_cols(aggs, feature_cols, window_sizes)
error
"""
agg_feature_columns = []
for code in ts_feature_cols:
ts_aggregations = [f"{code}/{agg}" for agg in aggregations]
agg_feature_columns.extend(ts_aggregations)
if window_sizes:
ts_aggregations = [f"{window_size}/{code}" for window_size in window_sizes]
return sorted(ts_aggregations)

The get_prediction_ts_cols function should handle the case where window_sizes is None to avoid potential errors.

-    if window_sizes:
+    if window_sizes is not None:

Comment on lines +381 to +421
def store_config_yaml(config_fp: Path, cfg: DictConfig):
"""Stores configuration parameters into a JSON file.

This function writes a dictionary of parameters, which includes patient partitioning
information and configuration details, to a specified JSON file.

Args:
- config_fp (Path): The file path for the JSON file where config should be stored.
- cfg (DictConfig): A configuration object containing settings like the number of patients
per sub-shard, minimum code inclusion frequency, and flags for updating or overwriting existing files.

Behavior:
- If config_fp exists and cfg.do_overwrite is False (without do_update being True), a
FileExistsError is raised to prevent unintentional data loss.

Raises:
- ValueError: If there are discrepancies between old and new parameters during an update.
- FileExistsError: If the file exists and overwriting is not allowed.

Example:
>>> cfg = DictConfig({
... "n_patients_per_sub_shard": 100,
... "min_code_inclusion_frequency": 5,
... "do_overwrite": True,
... })
>>> import tempfile
>>> from pathlib import Path
>>> with tempfile.NamedTemporaryFile() as temp_f:
... config_fp = Path(temp_f.name)
... store_config_yaml(config_fp, cfg)
... assert config_fp.exists()
... store_config_yaml(config_fp, cfg)
... cfg.do_overwrite = False
... try:
... store_config_yaml(config_fp, cfg)
... except FileExistsError as e:
... print("FileExistsError Error Triggered")
FileExistsError Error Triggered
"""
OmegaConf.save(cfg, config_fp)


The store_config_yaml function is robust and handles configuration storage with appropriate error handling. Consider adding a check for do_update in the behavior description to align with the implementation.

+    - do_update (bool): If True, allows updating existing files even if `do_overwrite` is False.

Comment on lines +38 to +48
def xgboost_singleton(trial: optuna.Trial, config: DictConfig) -> float:
for key, value in config.optuna.params.suggest_categorical.items():
logger.info(f"Optimizing {key} with {value}")
config.tabularization[key] = trial.suggest_categorical(key, value)
for key, value in config.optuna.params.suggest_float.items():
with open_dict(config):
config[key] = trial.suggest_float(key, **value)
for key, value in config.optuna.params.suggest_int.items():
with open_dict(config):
config[key] = trial.suggest_int(key, **value)
return launch_xgboost.main(config)

Tip

Codebase Verification

Ensure isolation of configuration changes in xgboost_singleton to prevent side effects. The search results indicate that the config object is modified directly in the xgboost_singleton function. This could potentially lead to unintended side effects if the config object is shared across different parts of the application.

  • src/MEDS_tabular_automl/scripts/sweep_xgboost.py: config[key] = trial.suggest_float(key, **value)
  • src/MEDS_tabular_automl/scripts/sweep_xgboost.py: config[key] = trial.suggest_int(key, **value)
Analysis chain

Ensure isolation of configuration changes in xgboost_singleton to prevent side effects.

Scripts executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type python $'config.tabularization[key] = trial.suggest_categorical(key, value)'

Length of output: 118



Script:

#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type py $'config.tabularization[key] = trial.suggest_categorical(key, value)'

Length of output: 83



Script:

#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type py 'config\['

Length of output: 241
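
A minimal sketch of the isolation the verification asks for: each Optuna trial works on its own deep copy of the OmegaConf config, so suggested values never leak into shared state. Only the loop bodies are taken from the diff above; the wrapper name and the `launch_xgboost` import path are assumptions for illustration, not the PR's actual code.

import copy

import optuna
from omegaconf import DictConfig, open_dict

# Assumed import path for the repo's training entry point (illustrative).
from MEDS_tabular_automl.scripts import launch_xgboost


def xgboost_singleton_isolated(trial: optuna.Trial, config: DictConfig) -> float:
    cfg = copy.deepcopy(config)  # trial-local copy; the caller's config is never mutated
    for key, value in cfg.optuna.params.suggest_categorical.items():
        cfg.tabularization[key] = trial.suggest_categorical(key, value)
    for key, value in cfg.optuna.params.suggest_float.items():
        with open_dict(cfg):
            cfg[key] = trial.suggest_float(key, **value)
    for key, value in cfg.optuna.params.suggest_int.items():
        with open_dict(cfg):
            cfg[key] = trial.suggest_int(key, **value)
    return launch_xgboost.main(cfg)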

Comment on lines +142 to +183
def add_static_missing_cols(
flat_df: DF_T, feature_columns: list[str], set_count_0_to_null: bool = False
) -> DF_T:
"""Normalizes columns in a DataFrame so all expected columns are present and appropriately typed.

Parameters:
- flat_df (DF_T): The DataFrame to be normalized.
- feature_columns (list[str]): A list of feature column names that should exist in the DataFrame.
- set_count_0_to_null (bool): A flag indicating whether counts of zero should be converted to nulls.

Returns:
- DF_T: The normalized DataFrame with all columns set to the correct type and zero-counts handled
if specified.

This function ensures that all necessary columns are added and typed correctly within
a DataFrame, potentially modifying zero counts to nulls based on the configuration.
"""
cols_to_add = set(feature_columns) - set(flat_df.columns)
cols_to_retype = set(feature_columns).intersection(set(flat_df.columns))

cols_to_add = [(c, get_static_col_dtype(c)) for c in cols_to_add]
cols_to_retype = [(c, get_static_col_dtype(c)) for c in cols_to_retype]

if "timestamp" in flat_df.columns:
key_cols = ["patient_id", "timestamp"]
else:
key_cols = ["patient_id"]

flat_df = flat_df.with_columns(
*[pl.lit(None, dtype=dt).alias(c) for c, dt in cols_to_add],
*[pl.col(c).cast(dt).alias(c) for c, dt in cols_to_retype],
).select(*key_cols, *feature_columns)

if not set_count_0_to_null:
return flat_df

flat_df = flat_df.collect()

flat_df = flat_df.with_columns(
pl.when(cs.ends_with("count") != 0).then(cs.ends_with("count")).keep_name()
).lazy()
return flat_df

The add_static_missing_cols function is robust in handling missing columns and data types. Consider adding more detailed documentation on the behavior when set_count_0_to_null is True.

+    - set_count_0_to_null (bool): If True, converts zero counts to nulls. This can be useful in certain statistical analyses where zero counts are treated differently.
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def add_static_missing_cols(
flat_df: DF_T, feature_columns: list[str], set_count_0_to_null: bool = False
) -> DF_T:
"""Normalizes columns in a DataFrame so all expected columns are present and appropriately typed.
Parameters:
- flat_df (DF_T): The DataFrame to be normalized.
- feature_columns (list[str]): A list of feature column names that should exist in the DataFrame.
- set_count_0_to_null (bool): A flag indicating whether counts of zero should be converted to nulls.
Returns:
- DF_T: The normalized DataFrame with all columns set to the correct type and zero-counts handled
if specified.
This function ensures that all necessary columns are added and typed correctly within
a DataFrame, potentially modifying zero counts to nulls based on the configuration.
"""
cols_to_add = set(feature_columns) - set(flat_df.columns)
cols_to_retype = set(feature_columns).intersection(set(flat_df.columns))
cols_to_add = [(c, get_static_col_dtype(c)) for c in cols_to_add]
cols_to_retype = [(c, get_static_col_dtype(c)) for c in cols_to_retype]
if "timestamp" in flat_df.columns:
key_cols = ["patient_id", "timestamp"]
else:
key_cols = ["patient_id"]
flat_df = flat_df.with_columns(
*[pl.lit(None, dtype=dt).alias(c) for c, dt in cols_to_add],
*[pl.col(c).cast(dt).alias(c) for c, dt in cols_to_retype],
).select(*key_cols, *feature_columns)
if not set_count_0_to_null:
return flat_df
flat_df = flat_df.collect()
flat_df = flat_df.with_columns(
pl.when(cs.ends_with("count") != 0).then(cs.ends_with("count")).keep_name()
).lazy()
return flat_df
def add_static_missing_cols(
flat_df: DF_T, feature_columns: list[str], set_count_0_to_null: bool = False
) -> DF_T:
"""Normalizes columns in a DataFrame so all expected columns are present and appropriately typed.
Parameters:
- flat_df (DF_T): The DataFrame to be normalized.
- feature_columns (list[str]): A list of feature column names that should exist in the DataFrame.
- set_count_0_to_null (bool): A flag indicating whether counts of zero should be converted to nulls.
- set_count_0_to_null (bool): If True, converts zero counts to nulls. This can be useful in certain statistical analyses where zero counts are treated differently.
Returns:
- DF_T: The normalized DataFrame with all columns set to the correct type and zero-counts handled
if specified.
This function ensures that all necessary columns are added and typed correctly within
a DataFrame, potentially modifying zero counts to nulls based on the configuration.
"""
cols_to_add = set(feature_columns) - set(flat_df.columns)
cols_to_retype = set(feature_columns).intersection(set(flat_df.columns))
cols_to_add = [(c, get_static_col_dtype(c)) for c in cols_to_add]
cols_to_retype = [(c, get_static_col_dtype(c)) for c in cols_to_retype]
if "timestamp" in flat_df.columns:
key_cols = ["patient_id", "timestamp"]
else:
key_cols = ["patient_id"]
flat_df = flat_df.with_columns(
*[pl.lit(None, dtype=dt).alias(c) for c, dt in cols_to_add],
*[pl.col(c).cast(dt).alias(c) for c, dt in cols_to_retype],
).select(*key_cols, *feature_columns)
if not set_count_0_to_null:
return flat_df
flat_df = flat_df.collect()
flat_df = flat_df.with_columns(
pl.when(cs.ends_with("count") != 0).then(cs.ends_with("count")).keep_name()
).lazy()
return flat_df
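
To make the documented flag concrete, here is a small standalone polars snippet (editor-added, not part of the PR) showing the effect of set_count_0_to_null=True: zero-valued "*count" columns are replaced with nulls.

import polars as pl

df = pl.DataFrame({"patient_id": [1, 2, 3], "A/static/count": [0, 2, 5]})
out = df.with_columns(
    pl.when(pl.col("A/static/count") != 0)
    .then(pl.col("A/static/count"))
    .otherwise(None)  # zero counts become null, mirroring set_count_0_to_null=True
    .alias("A/static/count")
)
print(out["A/static/count"].to_list())  # [None, 2, 5]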

Comment on lines +44 to +82
def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns) -> coo_array:
"""Merges static and time-series dataframes.

This function merges the static and time-series dataframes based on the patient_id column.

Args:
- feature_columns (List[str]): A list of feature columns to include in the merged dataframe.
- static_df (pd.DataFrame): A dataframe containing static features.
- ts_df (pd.DataFrame): A dataframe containing time-series features.

Returns:
- pd.DataFrame: A merged dataframe containing static and time-series features.
"""
# Make static data sparse and merge it with the time-series data
logger.info("Make static data sparse and merge it with the time-series data")
# Check static_df is sorted and unique
assert static_df.select(pl.col("patient_id")).collect().to_series().is_sorted()
assert (
static_df.select(pl.len()).collect().item()
== static_df.select(pl.col("patient_id").n_unique()).collect().item()
)
meds_df = get_unique_time_events_df(get_events_df(meds_df, feature_columns))

# load static data as sparse matrix
static_matrix = convert_to_matrix(
static_df, num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features)
)
# Duplicate static matrix rows to match time-series data
events_per_patient = (
meds_df.select(pl.col("patient_id").value_counts())
.unnest("patient_id")
.sort(by="patient_id")
.select(pl.col("count"))
.collect()
.to_series()
)
reindex_slices = np.repeat(range(len(events_per_patient)), events_per_patient)
static_matrix = static_matrix[reindex_slices, :]
return coo_array(static_matrix)

The function get_sparse_static_rep effectively merges static and time-series data. However, consider optimizing the data handling to avoid potential performance bottlenecks with large datasets.

- static_matrix = convert_to_matrix(
-     static_df, num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features)
- )
+ # Optimize data handling by directly using polars functions to handle large datasets efficiently.
+ static_matrix = static_df.to_coo_array(num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features))

Committable suggestion was skipped due to low confidence.
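
For readers following the merge step, a tiny self-contained sketch (illustrative data, not the PR's convert_to_matrix output) of the row duplication that turns one static row per patient into one row per event:

import numpy as np
from scipy.sparse import coo_array, csr_matrix

static_matrix = csr_matrix(np.array([[1, 0], [0, 3]]))  # one row per patient
events_per_patient = np.array([2, 3])                   # events for patient 0 and patient 1

# Repeat each patient's row index once per event, then index the sparse matrix.
reindex_slices = np.repeat(range(len(events_per_patient)), events_per_patient)
expanded = coo_array(static_matrix[reindex_slices, :])  # one row per event
print(expanded.toarray())
# [[1 0]
#  [1 0]
#  [0 3]
#  [0 3]
#  [0 3]]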

Comment on lines +63 to +67
def parse_static_feature_column(c: str) -> tuple[str, str, str]:
parts = c.split("/")
if len(parts) < 3:
raise ValueError(f"Column {c} is not a valid flat feature column!")
return ("/".join(parts[:-2]), parts[-2], parts[-1])

Tip

Codebase Verification

The inspection of the code surrounding the usage of parse_static_feature_column indicates that the function is used in contexts where the feature columns are derived from other functions like get_feature_names. However, there is no explicit validation of the format of these feature columns before they are passed to parse_static_feature_column.

  • src/MEDS_tabular_automl/utils.py: The function parse_static_feature_column is used directly without prior validation of the column format.
  • src/MEDS_tabular_automl/generate_static_features.py: The feature columns are obtained from get_feature_names, but there is no explicit check to ensure they are in the expected format before being passed to parse_static_feature_column.

To prevent runtime errors, it is recommended to add validation checks for the feature column format before they are passed to parse_static_feature_column.

  • src/MEDS_tabular_automl/utils.py
  • src/MEDS_tabular_automl/generate_static_features.py
Analysis chain

Ensure proper validation of the feature column format to prevent runtime errors.

Scripts executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify that all feature columns passed to `parse_static_feature_column` are in the expected format.

# Test: Search for the function usage and ensure proper format. Expect: Properly formatted feature columns.
rg --type python $'parse_static_feature_column'

Length of output: 77



Script:

#!/bin/bash
# Description: Verify that all feature columns passed to `parse_static_feature_column` are in the expected format.

# Test: Search for the function usage and ensure proper format. Expect: Properly formatted feature columns.
rg --type py 'parse_static_feature_column'

Length of output: 609



Script:

#!/bin/bash
# Description: Extract and inspect the relevant lines of code around the usages of `parse_static_feature_column` to verify the format of the feature columns being passed.

# Extract lines around the usage in `src/MEDS_tabular_automl/utils.py`
rg --type py -A 3 -B 3 'parse_static_feature_column' src/MEDS_tabular_automl/utils.py

# Extract lines around the usage in `src/MEDS_tabular_automl/generate_static_features.py`
rg --type py -A 3 -B 3 'parse_static_feature_column' src/MEDS_tabular_automl/generate_static_features.py

Length of output: 1705

Comment on lines +102 to +118
def filter_to_codes(
allowed_codes: list[str] | None,
min_code_inclusion_frequency: int,
code_metadata_fp: Path,
):
"""Returns intersection of allowed codes if they are specified, and filters to codes based on inclusion
frequency."""
if allowed_codes is None:
allowed_codes = get_feature_columns(code_metadata_fp)
feature_freqs = get_feature_freqs(code_metadata_fp)

code_freqs = {
code: freq
for code, freq in feature_freqs.items()
if (freq >= min_code_inclusion_frequency and code in set(allowed_codes))
}
return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency])

The function filter_to_codes effectively filters codes based on inclusion criteria. Consider adding more detailed logging to trace the filtering process, which could be helpful for debugging and maintenance.

+ logger.debug("Filtering codes with criteria: allowed_codes={}, min_code_inclusion_frequency={}", allowed_codes, min_code_inclusion_frequency)

Committable suggestion was skipped due to low confidence.
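
A standalone sketch of the suggested trace logging, assuming the repo's loguru logger; the demo function and its in-memory feature_freqs dict are illustrative and bypass the PR's file I/O.

from loguru import logger


def filter_codes_demo(allowed_codes, min_code_inclusion_frequency, feature_freqs):
    logger.debug(
        "Filtering codes with criteria: allowed_codes={}, min_code_inclusion_frequency={}",
        allowed_codes,
        min_code_inclusion_frequency,
    )
    kept = sorted(
        code
        for code, freq in feature_freqs.items()
        if freq >= min_code_inclusion_frequency and code in set(allowed_codes)
    )
    logger.debug("Kept {} of {} candidate codes", len(kept), len(feature_freqs))
    return kept


print(filter_codes_demo(["A/code", "B/code"], 2, {"A/code": 3, "B/code": 1, "C/code": 5}))
# ['A/code']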

Comment on lines +135 to +198
def filter_parquet(fp, allowed_codes: list[str]):
"""Loads Parquet with Polars and filters to allowed codes.

Args:
fp: Path to the Meds cohort shard
allowed_codes: List of codes to filter to.

Expect:
>>> from tempfile import NamedTemporaryFile
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp ┆ numerical_value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞══════╪════════════╪═════════════════╡
│ A ┆ 2021-01-01 ┆ null │
│ A ┆ 2021-01-01 ┆ null │
│ D ┆ null ┆ null │
│ D ┆ null ┆ null │
│ E ┆ 2021-01-03 ┆ null │
│ E ┆ 2021-01-04 ┆ 3 │
└──────┴────────────┴─────────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
# Drop values that are rare
# Drop Rare Static Codes
static_present_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("static/present", allowed_codes)
]
static_first_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("static/first", allowed_codes)
]
code_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("code/count", allowed_codes)
]
value_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
]

is_static_code = pl.col("timestamp").is_null()
is_numeric_code = pl.col("numerical_value").is_not_null()
rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
rare_static_value = is_static_code & ~pl.col("code").is_in(static_first_feature_columns) & is_numeric_code

# Remove rare numeric values by converting them to null
df = df.with_columns(
pl.when(rare_static_value | rare_ts_value)
.then(None)
.otherwise(pl.col("numerical_value"))
.alias("numerical_value")
)
# Drop rows with rare codes
df = df.filter(~(rare_static_code | rare_ts_code))
return df

The function filter_parquet is crucial for filtering Parquet files based on allowed codes. Ensure that the filtering logic is optimized to handle large datasets efficiently.

- df = df.filter(~(rare_static_code | rare_ts_code))
+ # Optimize filtering logic for large datasets.
+ df = df.lazy().filter(~(rare_static_code | rare_ts_code)).collect()
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def filter_parquet(fp, allowed_codes: list[str]):
"""Loads Parquet with Polars and filters to allowed codes.
Args:
fp: Path to the Meds cohort shard
allowed_codes: List of codes to filter to.
Expect:
>>> from tempfile import NamedTemporaryFile
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp ┆ numerical_value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞══════╪════════════╪═════════════════╡
│ A ┆ 2021-01-01 ┆ null │
│ A ┆ 2021-01-01 ┆ null │
│ D ┆ null ┆ null │
│ D ┆ null ┆ null │
│ E ┆ 2021-01-03 ┆ null │
│ E ┆ 2021-01-04 ┆ 3 │
└──────┴────────────┴─────────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
# Drop values that are rare
# Drop Rare Static Codes
static_present_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("static/present", allowed_codes)
]
static_first_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("static/first", allowed_codes)
]
code_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("code/count", allowed_codes)
]
value_feature_columns = [
clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
]
is_static_code = pl.col("timestamp").is_null()
is_numeric_code = pl.col("numerical_value").is_not_null()
rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
rare_static_value = is_static_code & ~pl.col("code").is_in(static_first_feature_columns) & is_numeric_code
# Remove rare numeric values by converting them to null
df = df.with_columns(
pl.when(rare_static_value | rare_ts_value)
.then(None)
.otherwise(pl.col("numerical_value"))
.alias("numerical_value")
)
# Drop rows with rare codes
df = df.filter(~(rare_static_code | rare_ts_code))
return df
# Drop rows with rare codes
# Optimize filtering logic for large datasets.
df = df.lazy().filter(~(rare_static_code | rare_ts_code)).collect()
return df

Comment on lines +29 to +41
def get_rolling_window_indicies(index_df, window_size):
"""Get the indices for the rolling windows."""
if window_size == "full":
timedelta = pd.Timedelta(150 * 52, unit="W") # just use 150 years as time delta
else:
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
)

The function get_rolling_window_indicies effectively computes indices for rolling windows. Consider adding error handling for cases where the window_size parameter is not properly formatted.

+ if not isinstance(window_size, str):
+     raise ValueError("window_size must be a string representing a time delta.")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def get_rolling_window_indicies(index_df, window_size):
"""Get the indices for the rolling windows."""
if window_size == "full":
timedelta = pd.Timedelta(150 * 52, unit="W") # just use 150 years as time delta
else:
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
)
def get_rolling_window_indicies(index_df, window_size):
"""Get the indices for the rolling windows."""
if not isinstance(window_size, str):
raise ValueError("window_size must be a string representing a time delta.")
if window_size == "full":
timedelta = pd.Timedelta(150 * 52, unit="W") # just use 150 years as time delta
else:
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
)
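
A minimal sketch of the window_size validation the comment proposes, pulled out as a standalone helper; the name parse_window_size is illustrative and not part of the PR.

import pandas as pd


def parse_window_size(window_size: str) -> pd.Timedelta:
    if not isinstance(window_size, str):
        raise ValueError("window_size must be a string representing a time delta.")
    if window_size == "full":
        return pd.Timedelta(150 * 52, unit="W")  # ~150 years, effectively unbounded
    try:
        return pd.Timedelta(window_size)
    except ValueError as e:
        raise ValueError(f"Invalid window_size {window_size!r}: {e}") from e


print(parse_window_size("30d"))   # 30 days 00:00:00
print(parse_window_size("full"))  # 54600 days 00:00:00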
