Clean #5
Conversation
…ts is correct when producing static representations
- Create individual scripts for: 1. Retrieving column names 2. Generating static representations 3. Generating dynamic representations 4. Summarizing data over windows - Add doctests for retrieving column names and generating static representations - Confirm functionality of the above tests
…dexed file types: - Code data: Files with a patient_id and timestamp index, containing columns for each code with binary presence indicators (1 and 0). - Value data: Files similar in structure but containing the numerical values observed for each code.
…n handling - Introduce VALID_AGGREGATIONS to define permissible aggregations. - Implement to generate dynamic column aliases based on window size and aggregation. - Extend for dynamic expression creation based on aggregation type and window size, handling both cumulative and windowed aggregations. - Enhance to apply specified aggregations over defined window sizes, ensuring comprehensive data summarization. - Update to handle multiple dataframes, aggregate data using specified window sizes and aggregations, and ensure inclusion of all specified feature columns, adding missing ones with default values. - Add extensive doctests to ensure accuracy of the summarization functions, demonstrating usage with both code and value data types.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…ressing comments in other files
…toML into esgpt_caching
…arse matrices instead of the sparse dataframes for values and codes. Also started modifying the summarization script to work with the sparse dataframes.
…/sum implemented at the moment)
…arse pandas array for each patient, now we just use sparse scipy matrices
…rging rows that occur at the same time based on the current aggregation strategy. For example, if the aggregation is sum, we sum up all rows on the same date; if the aggregation is count, we count all rows on the same date. (A sketch of this merging step follows the commit list below.)
…ock with a timestamp
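As an illustration of the merging strategy described in the commits above, here is a minimal, hypothetical sketch of collapsing same-timestamp rows of a SciPy sparse matrix; the function and variable names are illustrative and not the repository's actual code.

```python
import numpy as np
from scipy.sparse import csr_matrix


def merge_same_time_rows(matrix: csr_matrix, timestamps: np.ndarray, agg: str) -> csr_matrix:
    """Merge rows that share a timestamp, using the given aggregation ("sum" or "count")."""
    # Map each row to its timestamp group and build a (groups x rows) indicator matrix.
    _, group_ids = np.unique(timestamps, return_inverse=True)
    n_groups, n_rows = group_ids.max() + 1, matrix.shape[0]
    indicator = csr_matrix(
        (np.ones(n_rows), (group_ids, np.arange(n_rows))), shape=(n_groups, n_rows)
    )
    if agg == "sum":
        return indicator @ matrix  # sums stored values within each timestamp group
    if agg == "count":
        counts = matrix.copy()
        counts.data = np.ones_like(counts.data)  # treat every stored entry as one observation
        return indicator @ counts
    raise ValueError(f"Unsupported aggregation: {agg}")
```

Using a grouping-indicator matrix keeps the whole operation sparse, so no dense intermediate is materialized.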
Warning: Rate limit exceeded. @Oufattole has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 37 minutes and 59 seconds before requesting another review. After the wait time has elapsed, a review can be triggered again; we recommend spacing out commits to avoid hitting the rate limit. CodeRabbit enforces hourly rate limits for each developer per organization; paid plans have higher limits than the trial, open-source, and free plans, and further reviews are re-allowed after a brief timeout. Please see our FAQ for further information.

Walkthrough

The updates enhance the MEDS tabular AutoML system by upgrading Python to version 3.12, refining package installations, and improving test execution commands. Configuration files are added for tasks such as XGBoost model training, tabularization, and code description. Scripts and utility functions are introduced or updated to support these tasks, ensuring efficient data processing, feature generation, and model optimization.
Sequence Diagram(s) (Beta)

```mermaid
sequenceDiagram
    participant User
    participant Config
    participant Script
    participant Model
    participant Data
    User->>Config: Define settings
    Config->>Script: Pass configurations
    Script->>Data: Load and process data
    Script->>Model: Train model with data
    Model->>User: Provide results
```
Actionable comments posted: 16
Outside diff range and nitpick comments (5)
README.md (3)
2-2: Clarify the scope of "limited automatic tabular ML pipelines" to set accurate expectations for users.
32-44: The description of the repository's key pieces is clear and informative. However, consider adding more details about the specific ML models supported in the AutoML pipelines.
75-76: The explanation of the tabularization process is succinct. Consider linking to the `data/tabularize.py` script directly if it's hosted in the repository for easier access.
src/MEDS_tabular_automl/scripts/launch_xgboost.py (2)
24-47: Ensure proper documentation for all attributes in the `Iterator` class. Consider adding detailed descriptions for each attribute in the class docstring to improve code maintainability and readability.
302-321: Ensure that `XGBoostModel` initializes all its attributes properly. It's good practice to explicitly initialize all attributes in the constructor to avoid hidden bugs.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (26)
- .github/workflows/tests.yaml (1 hunks)
- .pre-commit-config.yaml (1 hunks)
- README.md (3 hunks)
- pyproject.toml (1 hunks)
- src/MEDS_tabular_automl/configs/default.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/describe_codes.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/launch_xgboost.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/tabularization.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/tabularization/default.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/task_specific_caching.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/tmp.yaml.yaml (1 hunks)
- src/MEDS_tabular_automl/describe_codes.py (1 hunks)
- src/MEDS_tabular_automl/file_name.py (1 hunks)
- src/MEDS_tabular_automl/generate_static_features.py (1 hunks)
- src/MEDS_tabular_automl/generate_summarized_reps.py (1 hunks)
- src/MEDS_tabular_automl/generate_ts_features.py (1 hunks)
- src/MEDS_tabular_automl/mapper.py (1 hunks)
- src/MEDS_tabular_automl/scripts/cache_task.py (1 hunks)
- src/MEDS_tabular_automl/scripts/describe_codes.py (1 hunks)
- src/MEDS_tabular_automl/scripts/launch_xgboost.py (1 hunks)
- src/MEDS_tabular_automl/scripts/sweep_xgboost.py (1 hunks)
- src/MEDS_tabular_automl/scripts/tabularize_static.py (1 hunks)
- src/MEDS_tabular_automl/scripts/tabularize_time_series.py (1 hunks)
- src/MEDS_tabular_automl/utils.py (1 hunks)
- tests/test_integration.py (1 hunks)
- tests/test_tabularize.py (1 hunks)
Files skipped from review due to trivial changes (5)
- .pre-commit-config.yaml
- src/MEDS_tabular_automl/configs/default.yaml
- src/MEDS_tabular_automl/configs/tabularization.yaml
- src/MEDS_tabular_automl/configs/tabularization/default.yaml
- src/MEDS_tabular_automl/configs/task_specific_caching.yaml
Additional context used
Ruff
src/MEDS_tabular_automl/scripts/tabularize_time_series.py
8-8: Module level import not at top of file (E402)
9-9: Module level import not at top of file (E402)
10-10: Module level import not at top of file (E402)
12-12: Module level import not at top of file (E402)
13-13: Module level import not at top of file (E402)
14-14: Module level import not at top of file (E402)
15-15: Module level import not at top of file (E402)
17-17: Module level import not at top of file (E402)
18-18: Module level import not at top of file (E402)
19-19: Module level import not at top of file (E402)
20-20: Module level import not at top of file (E402)
21-21: Module level import not at top of file (E402)
22-29: Module level import not at top of file (E402)
src/MEDS_tabular_automl/scripts/tabularize_static.py
13-13: Module level import not at top of file (E402)
15-15: Module level import not at top of file (E402)
17-23: Module level import not at top of file (E402)
24-24: Module level import not at top of file (E402)
25-25: Module level import not at top of file (E402)
26-26: Module level import not at top of file (E402)
27-34: Module level import not at top of file (E402)
src/MEDS_tabular_automl/generate_summarized_reps.py
6-6: Module level import not at top of file (E402)
7-7: Module level import not at top of file (E402)
9-9: Module level import not at top of file (E402)
10-10: Module level import not at top of file (E402)
tests/test_integration.py
5-5: Module level import not at top of file (E402)
6-6: Module level import not at top of file (E402)
7-7: Module level import not at top of file (E402)
8-8: Module level import not at top of file (E402)
9-9: Module level import not at top of file (E402)
11-11: Module level import not at top of file (E402)
12-12: Module level import not at top of file (E402)
13-22: Module level import not at top of file (E402)
24-24: Module level import not at top of file (E402)
25-25: Module level import not at top of file (E402)
26-33: Module level import not at top of file (E402)
tests/test_tabularize.py
5-5: Module level import not at top of file (E402)
6-6: Module level import not at top of file (E402)
7-7: Module level import not at top of file (E402)
8-8: Module level import not at top of file (E402)
9-9: Module level import not at top of file (E402)
11-11: Module level import not at top of file (E402)
12-12: Module level import not at top of file (E402)
14-14: Module level import not at top of file (E402)
15-15: Module level import not at top of file (E402)
16-23: Module level import not at top of file (E402)
24-31: Module level import not at top of file (E402)
LanguageTool
README.md
[uncategorized] ~62-~62: Possible missing comma found. (AI_HYDRA_LEO_MISSING_COMMA)
Context: ...ls and caches them with the event_id indexes which align them with the nearest prior...
[grammar] ~83-~83: It appears that a hyphen is missing in the plural noun “to-dos”? (TO_DO_HYPHEN)
Context: ...bularize.yaml`. ## AutoML Pipelines # TODOs 1. Leverage the "event bound aggregati...
[uncategorized] ~99-~99: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... Configuration File -MEDS_cohort_dir
: directory of MEDS format dataset that i...
[grammar] ~99-~99: Possible verb agreement error. Did you mean “formats”? (Some collective nouns can be treated as both singular and plural, so ‘format’ is not always incorrect.) (COLLECTIVE_NOUN_VERB_AGREEMENT_VBP)
Context: ... -MEDS_cohort_dir
: directory of MEDS format dataset that is ingested. - `tabularize...
[uncategorized] ~100-~100: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...at is ingested. -tabularized_data_dir
: output directory of tabularized data. -...
[uncategorized] ~101-~101: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d data. -min_code_inclusion_frequency
: The base feature inclusion frequency th...
[uncategorized] ~106-~106: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ype inclusion cutoff. -window_sizes
: Beyond writing out a raw, per-event fla...
[uncategorized] ~111-~111: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... the raw representation files. -codes
: A list of codes to include in the flat ...
[uncategorized] ~113-~113: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...d in the flat representation. -aggs
: A list of aggregations to apply to the ...
[uncategorized] ~114-~114: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...ter than 0. -n_patients_per_sub_shard
: The number of subjects that should be i...
[uncategorized] ~117-~117: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ... more memory efficient. -do_overwrite
: IfTrue
, this function will overwrite...
[uncategorized] ~119-~119: Loose punctuation mark. (UNLIKELY_OPENING_PUNCTUATION)
Context: ...in the target save directory. -seed
: The seed to use for random number gener...
Markdownlint
README.md
6-6: Expected: h2; Actual: h4 (MD001, heading-increment)
Heading levels should only increment by one level at a time
45-45: Expected: h2; Actual: h3 (MD001, heading-increment)
Heading levels should only increment by one level at a time
27-27: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document
31-31: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document
83-83: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document
97-97: null (MD025, single-title, single-h1)
Multiple top-level headings in the same document
Additional comments not posted (65)
src/MEDS_tabular_automl/configs/describe_codes.yaml (1)
1-14: Configuration settings in `describe_codes.yaml` appear consistent and well-structured.
.github/workflows/tests.yaml (3)
22-25: Update to Python 3.12 aligns with modern Python practices.
29-29: Inclusion of specific test dependencies enhances the testing strategy.
36-36: Ensure the `hf_cohort/` directory is correctly ignored for valid reasons.
Verification successful
The `hf_cohort/` directory is referenced in the `README.md` file, indicating it contains scripts for generating task labels and processing datasets. However, the directory itself does not exist in the repository. This suggests that the directory might be dynamically created or is expected to be present in a different environment.
- README.md: References to `hf_cohort/` scripts for generating labels and processing datasets.

Given this context, it seems reasonable to exclude the `hf_cohort/` directory from tests, likely because it is not part of the repository or is environment-specific.
Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify the reason for ignoring `hf_cohort/` directory in tests.
# Test: Check for specific configurations or data in `hf_cohort/` that justify its exclusion.
rg --type yaml --type python "hf_cohort" --files-with-matches
```

Length of output: 92

Script:

```bash
#!/bin/bash
# Description: Verify the reason for ignoring `hf_cohort/` directory in tests.
# Test: Check for specific configurations or data in `hf_cohort/` that justify its exclusion.
rg --type yaml --type py "hf_cohort" --files-with-matches
```

Length of output: 57

Script:

```bash
#!/bin/bash
# Description: List the contents of the `hf_cohort/` directory to understand its purpose.
ls -R hf_cohort/
```

Length of output: 76

Script:

```bash
#!/bin/bash
# Description: Search for references to `hf_cohort` in the entire repository to understand its context.
rg "hf_cohort"
```

Length of output: 336
src/MEDS_tabular_automl/file_name.py (1)
5-28: Enhancements in file management functions appear logically sound and well-implemented.
pyproject.toml (3)
6-7: Addition of new project authors is accurately reflected. Also applies to: 9-9
17-17: Updated dependencies support new functionalities and requirements.
37-38: Updated project URLs provide necessary resources and support.
src/MEDS_tabular_automl/configs/launch_xgboost.yaml (1)
1-81: Comprehensive and well-structured configuration for XGBoost model launching.
src/MEDS_tabular_automl/configs/tmp.yaml.yaml (1)
1-89: Temporary configuration settings are comprehensive and align with project standards.
src/MEDS_tabular_automl/scripts/sweep_xgboost.py (2)
20-32: The function `generate_permutations` is well-implemented and documented. It correctly generates all possible combinations of a list, which is essential for parameter tuning in XGBoost.
51-82: The main function is well-structured and effectively sets up and executes the Optuna study for XGBoost optimization.
src/MEDS_tabular_automl/scripts/cache_task.py (2)
40-50: The function `generate_row_cached_matrix` correctly checks matrix dimensions and processes the data, ensuring data integrity before transformation.
53-97: The main function in `cache_task.py` is well-organized and effectively handles the caching of time-series data, ensuring efficient data processing and storage.
src/MEDS_tabular_automl/scripts/describe_codes.py (3)
58-60: The `compute_fn` function effectively computes feature frequencies, which is crucial for filtering out infrequent events in the dataset.
61-62: The `write_fn` function correctly handles the writing of dataframes to the output path, ensuring data persistence.
33-112: The main function in `describe_codes.py` is well-structured and effectively manages the computation and storage of feature frequencies, facilitating efficient data analysis.
src/MEDS_tabular_automl/scripts/tabularize_time_series.py (3)
88-89: The `read_fn` function effectively filters the input data based on resolved codes, ensuring that only relevant data is processed.
91-106: The `compute_fn` function in `tabularize_time_series.py` correctly processes and summarizes time-series data, ensuring accurate data aggregation.
36-123: The main function is well-organized and effectively orchestrates the data processing pipeline for summarizing time-series data, ensuring efficient and structured data storage.
src/MEDS_tabular_automl/generate_ts_features.py (3)
20-23: The function `feature_name_to_code` correctly converts feature names to codes, which is essential for processing time series data.
25-38: The `get_long_code_df` function effectively pivots the codes data frame to a long format one-hot representation, which is crucial for time series analysis.
119-160: The function `get_flat_ts_rep` in `generate_ts_features.py` correctly produces a flat time series representation from a given data frame, ensuring accurate feature extraction and summarization.
src/MEDS_tabular_automl/scripts/tabularize_static.py (3)
136-137: The `read_fn` function effectively filters the input data based on resolved codes, ensuring that only relevant data is processed.
139-144: The `compute_fn` function in `tabularize_static.py` correctly processes and summarizes static data, ensuring accurate data aggregation.
41-160: The main function is well-organized and effectively orchestrates the data processing pipeline for tabularizing static data, ensuring efficient and structured data storage.
README.md (4)
Line range hint 7-18: Excellent explanation of "tabular data" in the context of MEDS datasets. This clarification helps in setting the right context for new users.
99-119: The YAML configuration details are well-documented. Ensure that all default values and options are clearly defined to avoid any confusion during setup.
47-64: Ensure the scripts mentioned are correctly linked and accessible in the repository, especially for new users who might follow these as a starting point.
21-25: Investigate the scalability of TemporAI for handling MEDS datasets as mentioned. This could potentially streamline the process if feasible.
src/MEDS_tabular_automl/generate_static_features.py (3)
27-42: The function `convert_to_matrix` is well-implemented with clear documentation. Ensure that the data types of the matrix elements are consistent with the expected types in downstream processing.
85-154: The function `summarize_static_measurements` provides a robust mechanism for aggregating static measurements. Ensure that all edge cases are handled, especially with data types and missing values.
157-182: The function `get_flat_static_rep` is crucial for generating a flat representation of static data. Verify that the output is correctly integrated into the overall data pipeline and that the data types are consistent.
src/MEDS_tabular_automl/describe_codes.py (2)
13-64: The function `compute_feature_frequencies` is well-documented and effectively computes feature frequencies. Ensure that the function is tested thoroughly, especially with edge cases involving unusual data distributions.
67-92: The function `convert_to_freq_dict` is a key component in converting dataframes to frequency dictionaries. Verify that it handles all possible data types and structures expected in the input dataframes.
src/MEDS_tabular_automl/generate_summarized_reps.py (4)
13-26: The function `sparse_aggregate` is well-implemented with support for multiple aggregation methods. Ensure that the function is tested with various types of sparse matrices to verify its robustness.
80-127: The function `compute_agg` applies aggregations to dataframes based on specified window sizes and aggregation methods. Ensure that the function handles all edge cases, especially with non-standard window sizes or aggregation methods.
178-245: The function `generate_summary` is crucial for generating summarized data frames. Verify that all feature columns specified are included in the final output and that default values are correctly applied where necessary.
44-77: The function `aggregate_matrix` is key to aggregating matrices based on window indices. Verify that the aggregation methods are correctly applied and that the output matrix has the expected dimensions and data types.
tests/test_integration.py (6)
36-43: Function `run_command` is well-implemented and handles subprocess execution and error reporting effectively.
47-87: The environment setup in the integration test is correctly implemented, ensuring proper directory and file management.
89-102: Script execution and output verification for `describe_codes` are correctly implemented, with appropriate assertions to check the output files and data integrity.
106-153: Execution and verification of the static data tabularization script are well-handled, with thorough checks for output file existence and data integrity.
161-208: The time series data tabularization script execution and its output verification are correctly implemented, ensuring data integrity and correct file handling.
210-250: Setup and execution of the task-specific caching script are correctly implemented, aligning with the overall testing strategy.
src/MEDS_tabular_automl/mapper.py (3)
14-53: Function `get_earliest_lock` is correctly implemented, effectively retrieving the earliest lock from a specified directory. The provided example is clear and demonstrates the function's usage well.
56-82: Function `register_lock` is well-implemented, correctly registering a lock file in the specified directory. The example provided effectively demonstrates the function's correct behavior.
85-278: Function `wrap` is robustly implemented, providing a comprehensive solution for wrapping file-in file-out transformations with caching and locking. The examples are detailed and effectively illustrate the function's capabilities and edge cases.
src/MEDS_tabular_automl/scripts/launch_xgboost.py (1)
140-157: Ensure that `_get_code_set` handles edge cases and errors gracefully. Consider adding error handling for cases where configuration settings might lead to unexpected results.
src/MEDS_tabular_automl/utils.py (16)
20-20: Update to `pl.LazyFrame` for `DF_T` aligns with the need for lazy evaluation in data processing.
21-21: Setting `WRITE_USE_PYARROW` to `True` ensures efficient data serialization using PyArrow.
22-22: Introduction of `ROW_IDX_NAME` as a constant improves code readability and maintainability.
70-73: The conversion to a sparse matrix using `coo_array` is efficient for handling large sparse datasets.
76-77: Using `np.result_type` to determine the minimum data type is a good practice for optimizing memory usage.
80-93: Optimization of data types in `sparse_matrix_to_array` is crucial for performance, especially in large datasets.
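For illustration, a minimal, hypothetical sketch of that kind of dtype minimization for a COO matrix's arrays; the helper name and array layout are assumptions, not the repository's exact code:

```python
import numpy as np


def shrink_coo_dtypes(data: np.ndarray, row: np.ndarray, col: np.ndarray, shape: tuple[int, int]):
    """Cast (data, row, col) to the smallest dtypes that can still represent their values."""
    # Smallest integer dtype able to index every row and column of the target shape.
    idx_dtype = np.result_type(np.min_scalar_type(shape[0]), np.min_scalar_type(shape[1]))
    # Suitable for count-style, non-negative data; float data would need a different policy.
    data_dtype = np.min_scalar_type(data.max()) if data.size else np.uint8
    return data.astype(data_dtype), row.astype(idx_dtype), col.astype(idx_dtype)
```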
129-137: The `parse_static_feature_column` function is used effectively here to determine data types based on aggregation. This is a good example of code reuse.
186-216: The `get_static_feature_cols` function efficiently generates feature columns from static data. The use of list comprehensions and set operations is commendable for performance.
219-248: The `get_ts_feature_cols` function mirrors the static version but for time-series data. The consistent approach enhances maintainability.
274-300: Combining static and time-series feature columns in `get_flat_rep_feature_cols` is a key functionality for feature engineering in time-series analysis.
303-336: The `load_meds_data` function is well-documented and handles different data splits efficiently. The use of globbing to find files is a practical approach.
349-358: The `get_unique_time_events_df` function ensures data integrity by maintaining unique timestamps. The use of assertions is a good practice for debugging.
362-371: The `get_feature_names` function is a utility that aids in selecting specific feature columns based on aggregation types. This function enhances modularity by separating concerns.
374-378: The `get_feature_indices` function leverages the `get_feature_names` function to map features to their indices, which is crucial for operations that require direct index access.
423-444: The `get_shard_prefix` function is a utility that extracts relative file paths. This is useful for managing file paths in a consistent manner across different modules.
339-346: The `get_events_df` function effectively filters and prepares event data. Ensure that the feature columns are correctly specified to avoid runtime issues.
Verification successful
The feature columns passed to the `get_events_df` function are correctly specified in the provided instances.
- tests/test_tabularize.py
- tests/test_integration.py
- src/MEDS_tabular_automl/generate_ts_features.py
- src/MEDS_tabular_automl/generate_static_features.py
Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify that the feature columns passed to `get_events_df` are correctly specified and exist in the data.
# Test: Search for the function usage and check feature column specifications. Expect: Correctly specified feature columns.
rg --type python $'get_events_df'
```

Length of output: 63

Script:

```bash
#!/bin/bash
# Description: Verify that the feature columns passed to `get_events_df` are correctly specified and exist in the data.
# Test: Search for the function usage and check feature column specifications. Expect: Correctly specified feature columns.
rg --type py $'get_events_df'
```

Length of output: 1352

Script:

```bash
#!/bin/bash
# Description: Extract lines around the usage of `get_events_df` to verify feature column specifications.
rg --type py -A 3 -B 3 $'get_events_df' tests/test_tabularize.py tests/test_integration.py src/MEDS_tabular_automl/generate_ts_features.py src/MEDS_tabular_automl/generate_static_features.py
```

Length of output: 6560
```python
def _get_code_masks(self, feature_columns: list, codes_set: set) -> Mapping[str, list[bool]]:
    """Create boolean masks for filtering features.

    Creates a dictionary of boolean masks for each aggregation type. The masks are used to filter
    the feature columns based on the specified included codes and minimum code inclusion frequency.

    Args:
        feature_columns: List of feature columns.
        codes_set: Set of codes to include.

    Returns:
        Dictionary of code masks for each aggregation.
    """
    code_masks = {}
    for agg in set(self.cfg.tabularization.aggs):
        feature_ids = get_feature_indices(agg, feature_columns)
        code_mask = [True if idx in codes_set else False for idx in feature_ids]
        code_masks[agg] = code_mask
    return code_masks
```
Optimize the list comprehension in `_get_code_masks`. Consider using numpy arrays or other efficient data structures for large-scale operations to improve performance.
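A minimal sketch of that NumPy-based alternative, assuming `get_feature_indices` returns integer indices as in the snippet above and that `codes_set` contains integers; note it yields boolean arrays rather than Python lists, so downstream code may need `.tolist()`:

```python
import numpy as np


def _get_code_masks(self, feature_columns: list, codes_set: set) -> dict[str, np.ndarray]:
    """Vectorized variant: build one boolean mask per aggregation with NumPy."""
    codes = np.fromiter(codes_set, dtype=np.int64)  # membership test via np.isin
    code_masks = {}
    for agg in set(self.cfg.tabularization.aggs):
        feature_ids = np.asarray(get_feature_indices(agg, feature_columns), dtype=np.int64)
        code_masks[agg] = np.isin(feature_ids, codes)
    return code_masks
```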
```python
def _load_matrix(self, path: Path) -> sp.csc_matrix:
    """Load a sparse matrix from disk.

    Args:
        - path (Path): Path to the sparse matrix.

    Returns:
        - sp.csc_matrix: Sparse matrix.
    """
    npzfile = np.load(path)
    array, shape = npzfile["array"], npzfile["shape"]
    if array.shape[0] != 3:
        raise ValueError(f"Expected array to have 3 rows, but got {array.shape[0]} rows")
    data, row, col = array
    return sp.csc_matrix((data, (row, col)), shape=shape)
```
Validate the loaded matrix shape more robustly in `_load_matrix`. Consider adding more comprehensive checks for the matrix dimensions to ensure data integrity.
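As a sketch of what such validation could look like, assuming the same `npz` layout as in the snippet above (a 3 x nnz `array` plus a `shape` entry); this is illustrative, not the repository's implementation:

```python
import numpy as np
import scipy.sparse as sp
from pathlib import Path


def _load_matrix(self, path: Path) -> sp.csc_matrix:
    """Load a sparse matrix from disk, validating the stored arrays before reconstruction."""
    npzfile = np.load(path)
    if not {"array", "shape"} <= set(npzfile.files):
        raise ValueError(f"{path} is missing required keys 'array' and/or 'shape'")
    array, shape = npzfile["array"], tuple(npzfile["shape"])
    if array.ndim != 2 or array.shape[0] != 3:
        raise ValueError(f"Expected a 3 x nnz array, but got shape {array.shape}")
    if len(shape) != 2:
        raise ValueError(f"Expected a 2-D target shape, but got {shape}")
    data, row, col = array
    if row.size and (row.max() >= shape[0] or col.max() >= shape[1]):
        raise ValueError(f"Stored indices exceed the declared shape {shape}")
    return sp.csc_matrix((data, (row, col)), shape=shape)
```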
```python
def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]:
    """Loads valid event ids and labels for each shard.

    Returns:
        - Tuple[Mapping[int, list], Mapping[int, list]]: Tuple containing:
            dictionary from shard number to list of valid event ids -- used for indexing rows
                in the sparse matrix
            dictionary from shard number to list of labels for these valid event ids
    """
    label_fps = {
        shard: (Path(self.cfg.input_label_dir) / self.split / shard).with_suffix(".parquet")
        for shard in self._data_shards
        for shard in self._data_shards
    }
    cached_labels, cached_event_ids = dict(), dict()
    for shard, label_fp in label_fps.items():
        label_df = pl.scan_parquet(label_fp)
        cached_event_ids[shard] = label_df.select(pl.col("event_id")).collect().to_series()

        # TODO: check this for Nan or any other case we need to worry about
        cached_labels[shard] = label_df.select(pl.col("label")).collect().to_series()
        if self.cfg.model_params.iterator.binarize_task:
            cached_labels[shard] = cached_labels[shard].map_elements(
                lambda x: 1 if x > 0 else 0, return_dtype=pl.Int8
            )

    return cached_event_ids, cached_labels
```
Address the TODO in `load_labels` and ensure robust error handling. Would you like me to help implement the error handling or open a GitHub issue to track this task?
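A minimal sketch of what handling the flagged TODO could look like, assuming Polars is imported as `pl` and that null or NaN labels should fail fast rather than propagate into training; the helper name is hypothetical:

```python
import polars as pl


def _validate_labels(labels: pl.Series, shard: str) -> pl.Series:
    """Fail fast on missing or non-finite labels before they reach the model."""
    n_null = labels.null_count()
    if n_null:
        raise ValueError(f"Shard {shard} has {n_null} null labels")
    if labels.dtype in (pl.Float32, pl.Float64) and labels.is_nan().any():
        raise ValueError(f"Shard {shard} contains NaN labels")
    return labels
```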
```python
@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
    """Optimize the model based on the provided configuration.

    Args:
        - cfg (DictConfig): Configuration dictionary.

    Returns:
        - float: Evaluation result.
    """
    if not cfg.loguru_init:
        hydra_loguru_init()

    model = XGBoostModel(cfg)
    model.train()

    print(
        "Time Profiling for window sizes ",
        f"{cfg.tabularization.window_sizes} and min ",
        "code frequency of {cfg.tabularization.min_code_inclusion_frequency}:",
    )
    print("Train Time: \n", model._profile_durations())
    print("Train Iterator Time: \n", model.itrain._profile_durations())
    print("Tuning Iterator Time: \n", model.ituning._profile_durations())
    print("Held Out Iterator Time: \n", model.iheld_out._profile_durations())

    # save model
    save_dir = Path(cfg.output_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving the model to directory: {save_dir}")
    model.model.save_model(save_dir / "model.json")
    auc = model.evaluate()
    logger.info(f"AUC: {auc}")
    return auc
```
Optimize the main function in `launch_xgboost.py`. Consider refactoring the main function to separate configuration loading, model training, and evaluation into distinct functions or methods for better modularity and testability.
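A rough sketch of the suggested split; it reuses names from the snippet above (`XGBoostModel`, `hydra_loguru_init`, `config_yaml`), and the helper function names are hypothetical rather than the module's actual API:

```python
from pathlib import Path

import hydra
from omegaconf import DictConfig


def train_model(cfg: DictConfig) -> "XGBoostModel":
    """Build and train the model from the resolved configuration."""
    if not cfg.loguru_init:
        hydra_loguru_init()
    model = XGBoostModel(cfg)
    model.train()
    return model


def save_model(model: "XGBoostModel", output_dir: str) -> None:
    """Persist the trained booster to disk."""
    save_dir = Path(output_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    model.model.save_model(save_dir / "model.json")


def evaluate_model(model: "XGBoostModel") -> float:
    """Return the held-out evaluation metric (AUC)."""
    return model.evaluate()


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
    model = train_model(cfg)
    save_model(model, cfg.output_dir)
    return evaluate_model(model)
```

Keeping `main` as a thin orchestrator makes each stage testable in isolation.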
```python
    len(list_subdir_files(Path(cfg.input_dir).parent, "parquet")) == 4
), "MEDS train split Data Files Should be 4!"
for f in meds_files:
    assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!"
split_json = json.load(StringIO(SPLITS_JSON))
splits_fp = MEDS_cohort_dir / "splits.json"
json.dump(split_json, splits_fp.open("w"))
# Step 1: Describe Codes - compute code frequencies
describe_codes.main(cfg)

assert (Path(cfg.output_dir) / "config.yaml").is_file()
assert Path(cfg.output_filepath).is_file()

feature_columns = get_feature_columns(cfg.output_filepath)
assert get_feature_names("code/count", feature_columns) == sorted(CODE_COLS)
assert get_feature_names("static/present", feature_columns) == sorted(STATIC_PRESENT_COLS)
assert get_feature_names("static/first", feature_columns) == sorted(STATIC_FIRST_COLS)
for value_agg in VALUE_AGGREGATIONS:
    assert get_feature_names(value_agg, feature_columns) == sorted(VALUE_COLS)

# Step 2: Tabularization
tabularize_static_config = {
    "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
    "do_overwrite": False,
    "seed": 1,
    "hydra.verbose": True,
    "tqdm": False,
    "loguru_init": True,
    "tabularization.min_code_inclusion_frequency": 1,
    "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
    "tabularization.window_sizes": "[30d,365d,full]",
}

with initialize(
    version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
):  # path to config.yaml
    overrides = [f"{k}={v}" for k, v in tabularize_static_config.items()]
    cfg = compose(config_name="tabularization", overrides=overrides)  # config.yaml
tabularize_static.main(cfg)
output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz"))
actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files]
assert set(actual_files) == set(EXPECTED_STATIC_FILES)
# Check the files are not empty
for f in output_files:
    static_matrix = load_matrix(f)
    assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!"
    expected_num_cols = len(get_feature_names(f"static/{f.stem}", feature_columns))
    assert static_matrix.shape[1] == expected_num_cols, (
        f"Static Data Tabular Dataframe Should have {expected_num_cols}"
        f"Columns but has {static_matrix.shape[1]}!"
    )
    split = f.parts[-5]
    shard_num = f.parts[-4]
    med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
    expected_num_rows = (
        get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
        .collect()
        .shape[0]
    )
    assert static_matrix.shape[0] == expected_num_rows, (
        f"Static Data matrix Should have {expected_num_rows}"
        f" rows but has {static_matrix.shape[0]}!"
    )
allowed_codes = cfg.tabularization._resolved_codes
num_allowed_codes = len(allowed_codes)
feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
assert num_allowed_codes == len(
    feature_columns
), f"Should have {len(feature_columns)} codes but has {num_allowed_codes}"

tabularize_time_series.main(cfg)

# confirm summary files exist:
output_files = list_subdir_files(cfg.output_dir, "npz")
actual_files = [
    get_shard_prefix(Path(cfg.output_dir), each) + ".npz"
    for each in output_files
    if "none/static" not in str(each)
]
assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES)
for f in output_files:
    ts_matrix = load_matrix(f)
    assert ts_matrix.shape[0] > 0, "Time-Series Tabular Dataframe Should not be Empty!"
    expected_num_cols = len(get_feature_names(f"{f.parent.stem}/{f.stem}", feature_columns))
    assert ts_matrix.shape[1] == expected_num_cols, (
        f"Time-Series Tabular Dataframe Should have {expected_num_cols}"
        f"Columns but has {ts_matrix.shape[1]}!"
    )
    split = f.parts[-5]
    shard_num = f.parts[-4]
    med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
    expected_num_rows = (
        get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
        .collect()
        .shape[0]
    )
    assert ts_matrix.shape[0] == expected_num_rows, (
        f"Time-Series Data matrix Should have {expected_num_rows}"
        f" rows but has {ts_matrix.shape[0]}!"
    )
```
Refactor `test_tabularize` to improve readability and maintainability. Consider breaking down the `test_tabularize` function into smaller, more focused test functions. This can improve readability and make it easier to identify and address specific failures.
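One possible shape for that split, sketched with a shared pytest fixture; the `build_test_config` helper is a placeholder, and the script modules are assumed to be imported as in the existing test:

```python
from pathlib import Path

import pytest


@pytest.fixture()
def tabularize_cfg(tmp_path):
    """Build the Hydra config once and run the describe step shared by the smaller tests."""
    cfg = build_test_config(tmp_path)  # hypothetical helper wrapping initialize()/compose()
    describe_codes.main(cfg)
    return cfg


def test_describe_codes_outputs(tabularize_cfg):
    assert (Path(tabularize_cfg.output_dir) / "config.yaml").is_file()
    assert Path(tabularize_cfg.output_filepath).is_file()


def test_static_tabularization(tabularize_cfg):
    tabularize_static.main(tabularize_cfg)
    assert list(Path(tabularize_cfg.output_dir).glob("**/static/**/*.npz"))


def test_time_series_tabularization(tabularize_cfg):
    tabularize_time_series.main(tabularize_cfg)
    assert list_subdir_files(tabularize_cfg.output_dir, "npz")
```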
```python
def write_df(df: coo_array, fp: Path, **kwargs):
    """Write shard to disk."""
    do_overwrite = kwargs.get("do_overwrite", False)

    if not do_overwrite and fp.is_file():
        raise FileExistsError(f"{fp} exists and do_overwrite is {do_overwrite}!")

    fp.parent.mkdir(exist_ok=True, parents=True)

    if isinstance(df, pl.LazyFrame):
        df.collect().write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW)
    elif isinstance(df, pl.DataFrame):
        df.write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW)
    elif isinstance(df, coo_array):
        store_matrix(df, fp)
    else:
        raise TypeError(f"Unsupported type for df: {type(df)}")
```
The `write_df` function's flexibility to handle different data types and overwrite behavior is well-implemented. However, ensure that the error messages are informative and guide the user on how to resolve potential issues.

```diff
- raise FileExistsError(f"{fp} exists and do_overwrite is {do_overwrite}!")
+ raise FileExistsError(f"File {fp} exists. Set 'do_overwrite' to True to overwrite it.")
```
```python
def get_prediction_ts_cols(
    aggregations: list[str], ts_feature_cols: DF_T, window_sizes: list[str] | None = None
) -> list[str]:
    """Generates a list of feature column names that will be used for downstream task
    Examples:
    >>> feature_cols = ['A/code', 'A/value', 'C/code', 'C/value']
    >>> window_sizes = None
    >>> aggs = ['value/sum', 'code/count']
    >>> get_prediction_ts_cols(aggs, feature_cols, window_sizes)
    error
    >>> window_sizes = ["1d"]
    >>> get_prediction_ts_cols(aggs, feature_cols, window_sizes)
    error
    """
    agg_feature_columns = []
    for code in ts_feature_cols:
        ts_aggregations = [f"{code}/{agg}" for agg in aggregations]
        agg_feature_columns.extend(ts_aggregations)
    if window_sizes:
        ts_aggregations = [f"{window_size}/{code}" for window_size in window_sizes]
    return sorted(ts_aggregations)
```
The `get_prediction_ts_cols` function should handle the case where `window_sizes` is `None` to avoid potential errors.

```diff
- if window_sizes:
+ if window_sizes is not None:
```
```python
def store_config_yaml(config_fp: Path, cfg: DictConfig):
    """Stores configuration parameters into a JSON file.

    This function writes a dictionary of parameters, which includes patient partitioning
    information and configuration details, to a specified JSON file.

    Args:
        - config_fp (Path): The file path for the JSON file where config should be stored.
        - cfg (DictConfig): A configuration object containing settings like the number of patients
            per sub-shard, minimum code inclusion frequency, and flags for updating or overwriting existing files.

    Behavior:
        - If config_fp exists and cfg.do_overwrite is False (without do_update being True), a
            FileExistsError is raised to prevent unintentional data loss.

    Raises:
        - ValueError: If there are discrepancies between old and new parameters during an update.
        - FileExistsError: If the file exists and overwriting is not allowed.

    Example:
    >>> cfg = DictConfig({
    ...     "n_patients_per_sub_shard": 100,
    ...     "min_code_inclusion_frequency": 5,
    ...     "do_overwrite": True,
    ... })
    >>> import tempfile
    >>> from pathlib import Path
    >>> with tempfile.NamedTemporaryFile() as temp_f:
    ...     config_fp = Path(temp_f.name)
    ...     store_config_yaml(config_fp, cfg)
    ...     assert config_fp.exists()
    ...     store_config_yaml(config_fp, cfg)
    ...     cfg.do_overwrite = False
    ...     try:
    ...         store_config_yaml(config_fp, cfg)
    ...     except FileExistsError as e:
    ...         print("FileExistsError Error Triggered")
    FileExistsError Error Triggered
    """
    OmegaConf.save(cfg, config_fp)
```
The `store_config_yaml` function is robust and handles configuration storage with appropriate error handling. Consider adding a check for `do_update` in the behavior description to align with the implementation.

```diff
+ - do_update (bool): If True, allows updating existing files even if `do_overwrite` is False.
```
```python
def xgboost_singleton(trial: optuna.Trial, config: DictConfig) -> float:
    for key, value in config.optuna.params.suggest_categorical.items():
        logger.info(f"Optimizing {key} with {value}")
        config.tabularization[key] = trial.suggest_categorical(key, value)
    for key, value in config.optuna.params.suggest_float.items():
        with open_dict(config):
            config[key] = trial.suggest_float(key, **value)
    for key, value in config.optuna.params.suggest_int.items():
        with open_dict(config):
            config[key] = trial.suggest_int(key, **value)
    return launch_xgboost.main(config)
```
Tip: Codebase Verification

Ensure isolation of configuration changes in `xgboost_singleton` to prevent side effects. The search results indicate that the `config` object is modified directly in the `xgboost_singleton` function. This could potentially lead to unintended side effects if the `config` object is shared across different parts of the application.

src/MEDS_tabular_automl/scripts/sweep_xgboost.py: config[key] = trial.suggest_float(key, **value)
src/MEDS_tabular_automl/scripts/sweep_xgboost.py: config[key] = trial.suggest_int(key, **value)

Analysis chain

Ensure isolation of configuration changes in `xgboost_singleton` to prevent side effects.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type python $'config.tabularization[key] = trial.suggest_categorical(key, value)'
Length of output: 118
Script:
#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type py $'config.tabularization[key] = trial.suggest_categorical(key, value)'
Length of output: 83
Script:
#!/bin/bash
# Description: Verify that changes to the configuration in `xgboost_singleton` do not affect other parts of the application.
# Test: Check for any shared state modifications that could lead to unintended side effects.
rg --type py 'config\['
Length of output: 241
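Following up on the isolation concern above, a minimal sketch of one way to achieve it: copy the configuration per trial instead of mutating the shared object. This assumes OmegaConf and Optuna as in the reviewed script, with `launch_xgboost` and `logger` imported from the surrounding module.

```python
import copy

import optuna
from omegaconf import DictConfig, open_dict


def xgboost_singleton(trial: optuna.Trial, config: DictConfig) -> float:
    """Work on a per-trial copy so Optuna suggestions never leak into the shared config."""
    trial_config = copy.deepcopy(config)  # OmegaConf containers support deepcopy
    for key, value in trial_config.optuna.params.suggest_categorical.items():
        trial_config.tabularization[key] = trial.suggest_categorical(key, value)
    for key, value in trial_config.optuna.params.suggest_float.items():
        with open_dict(trial_config):
            trial_config[key] = trial.suggest_float(key, **value)
    for key, value in trial_config.optuna.params.suggest_int.items():
        with open_dict(trial_config):
            trial_config[key] = trial.suggest_int(key, **value)
    return launch_xgboost.main(trial_config)
```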
```python
def add_static_missing_cols(
    flat_df: DF_T, feature_columns: list[str], set_count_0_to_null: bool = False
) -> DF_T:
    """Normalizes columns in a DataFrame so all expected columns are present and appropriately typed.

    Parameters:
    - flat_df (DF_T): The DataFrame to be normalized.
    - feature_columns (list[str]): A list of feature column names that should exist in the DataFrame.
    - set_count_0_to_null (bool): A flag indicating whether counts of zero should be converted to nulls.

    Returns:
    - DF_T: The normalized DataFrame with all columns set to the correct type and zero-counts handled
        if specified.

    This function ensures that all necessary columns are added and typed correctly within
    a DataFrame, potentially modifying zero counts to nulls based on the configuration.
    """
    cols_to_add = set(feature_columns) - set(flat_df.columns)
    cols_to_retype = set(feature_columns).intersection(set(flat_df.columns))

    cols_to_add = [(c, get_static_col_dtype(c)) for c in cols_to_add]
    cols_to_retype = [(c, get_static_col_dtype(c)) for c in cols_to_retype]

    if "timestamp" in flat_df.columns:
        key_cols = ["patient_id", "timestamp"]
    else:
        key_cols = ["patient_id"]

    flat_df = flat_df.with_columns(
        *[pl.lit(None, dtype=dt).alias(c) for c, dt in cols_to_add],
        *[pl.col(c).cast(dt).alias(c) for c, dt in cols_to_retype],
    ).select(*key_cols, *feature_columns)

    if not set_count_0_to_null:
        return flat_df

    flat_df = flat_df.collect()

    flat_df = flat_df.with_columns(
        pl.when(cs.ends_with("count") != 0).then(cs.ends_with("count")).keep_name()
    ).lazy()
    return flat_df
```
The `add_static_missing_cols` function is robust in handling missing columns and data types. Consider adding more detailed documentation on the behavior when `set_count_0_to_null` is `True`.

```diff
+ - set_count_0_to_null (bool): If True, converts zero counts to nulls. This can be useful in certain statistical analyses where zero counts are treated differently.
```
def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns) -> coo_array:
    """Merges static and time-series dataframes.

    This function merges the static and time-series dataframes based on the patient_id column.

    Args:
    - feature_columns (List[str]): A list of feature columns to include in the merged dataframe.
    - static_df (pd.DataFrame): A dataframe containing static features.
    - ts_df (pd.DataFrame): A dataframe containing time-series features.

    Returns:
    - pd.DataFrame: A merged dataframe containing static and time-series features.
    """
    # Make static data sparse and merge it with the time-series data
    logger.info("Make static data sparse and merge it with the time-series data")
    # Check static_df is sorted and unique
    assert static_df.select(pl.col("patient_id")).collect().to_series().is_sorted()
    assert (
        static_df.select(pl.len()).collect().item()
        == static_df.select(pl.col("patient_id").n_unique()).collect().item()
    )
    meds_df = get_unique_time_events_df(get_events_df(meds_df, feature_columns))

    # load static data as sparse matrix
    static_matrix = convert_to_matrix(
        static_df, num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features)
    )
    # Duplicate static matrix rows to match time-series data
    events_per_patient = (
        meds_df.select(pl.col("patient_id").value_counts())
        .unnest("patient_id")
        .sort(by="patient_id")
        .select(pl.col("count"))
        .collect()
        .to_series()
    )
    reindex_slices = np.repeat(range(len(events_per_patient)), events_per_patient)
    static_matrix = static_matrix[reindex_slices, :]
    return coo_array(static_matrix)
The function get_sparse_static_rep effectively merges static and time-series data. However, consider optimizing the data handling to avoid potential performance bottlenecks with large datasets.
- static_matrix = convert_to_matrix(
- static_df, num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features)
- )
+ # Optimize data handling by directly using polars functions to handle large datasets efficiently.
+ static_matrix = static_df.to_coo_array(num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features))
Committable suggestion was skipped due to low confidence.
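As a point of reference, here is a self-contained sketch of the row-duplication step used above (np.repeat over per-patient event counts, then sparse row indexing); the shapes and values are made up for illustration:

import numpy as np
from scipy.sparse import coo_array, csr_array

# Hypothetical static matrix: 2 patients x 3 static features.
static_matrix = csr_array(np.array([[1, 0, 2], [0, 4, 0]]))
# Hypothetical event counts per patient: patient 0 has 2 events, patient 1 has 3.
events_per_patient = np.array([2, 3])

# Repeat each patient's static row once per event, mirroring the reindex step above.
reindex_slices = np.repeat(np.arange(len(events_per_patient)), events_per_patient)
expanded = coo_array(static_matrix[reindex_slices, :])
print(expanded.shape)  # (5, 3): one static row per time-series event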
def parse_static_feature_column(c: str) -> tuple[str, str, str, str]:
    parts = c.split("/")
    if len(parts) < 3:
        raise ValueError(f"Column {c} is not a valid flat feature column!")
    return ("/".join(parts[:-2]), parts[-2], parts[-1])
Tip
Codebase Verification
The inspection of the code surrounding the usage of parse_static_feature_column indicates that the function is used in contexts where the feature columns are derived from other functions like get_feature_names. However, there is no explicit validation of the format of these feature columns before they are passed to parse_static_feature_column.
- src/MEDS_tabular_automl/utils.py: The function parse_static_feature_column is used directly without prior validation of the column format.
- src/MEDS_tabular_automl/generate_static_features.py: The feature columns are obtained from get_feature_names, but there is no explicit check to ensure they are in the expected format before being passed to parse_static_feature_column.
To prevent runtime errors, it is recommended to add validation checks for the feature column format before they are passed to parse_static_feature_column; a hedged sketch of such a check follows the verification scripts below.
- src/MEDS_tabular_automl/utils.py
- src/MEDS_tabular_automl/generate_static_features.py
Analysis chain
Ensure proper validation of the feature column format to prevent runtime errors.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that all feature columns passed to `parse_static_feature_column` are in the expected format.
# Test: Search for the function usage and ensure proper format. Expect: Properly formatted feature columns.
rg --type python $'parse_static_feature_column'
Length of output: 77
Script:
#!/bin/bash
# Description: Verify that all feature columns passed to `parse_static_feature_column` are in the expected format.
# Test: Search for the function usage and ensure proper format. Expect: Properly formatted feature columns.
rg --type py 'parse_static_feature_column'
Length of output: 609
Script:
#!/bin/bash
# Description: Extract and inspect the relevant lines of code around the usages of `parse_static_feature_column` to verify the format of the feature columns being passed.
# Extract lines around the usage in `src/MEDS_tabular_automl/utils.py`
rg --type py -A 3 -B 3 'parse_static_feature_column' src/MEDS_tabular_automl/utils.py
# Extract lines around the usage in `src/MEDS_tabular_automl/generate_static_features.py`
rg --type py -A 3 -B 3 'parse_static_feature_column' src/MEDS_tabular_automl/generate_static_features.py
Length of output: 1705
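A minimal sketch of the kind of pre-check the verification recommends; the helper name and example columns are hypothetical, and the three-part format is inferred from parse_static_feature_column itself:

def is_valid_static_feature_column(c: str) -> bool:
    # A flat feature column is expected to split into at least three non-empty
    # "/"-separated parts, e.g. "<code>/<temporality>/<aggregation>".
    parts = c.split("/")
    return len(parts) >= 3 and all(parts)


# Usage sketch: filter out (or fail fast on) malformed columns before parsing.
candidate_columns = ["HR/static/first", "not_a_feature_column"]
valid_columns = [c for c in candidate_columns if is_valid_static_feature_column(c)]
print(valid_columns)  # ['HR/static/first']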
def filter_to_codes(
    allowed_codes: list[str] | None,
    min_code_inclusion_frequency: int,
    code_metadata_fp: Path,
):
    """Returns intersection of allowed codes if they are specified, and filters to codes based on inclusion
    frequency."""
    if allowed_codes is None:
        allowed_codes = get_feature_columns(code_metadata_fp)
    feature_freqs = get_feature_freqs(code_metadata_fp)

    code_freqs = {
        code: freq
        for code, freq in feature_freqs.items()
        if (freq >= min_code_inclusion_frequency and code in set(allowed_codes))
    }
    return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency])
The function filter_to_codes effectively filters codes based on inclusion criteria. Consider adding more detailed logging to trace the filtering process, which could be helpful for debugging and maintenance.
+ logger.debug("Filtering codes with criteria: allowed_codes={}, min_code_inclusion_frequency={}", allowed_codes, min_code_inclusion_frequency)
Committable suggestion was skipped due to low confidence.
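A hedged sketch of what the suggested tracing could look like; it assumes the module's logger is loguru's (which accepts brace-style placeholders) and takes the frequency map as an argument so the snippet stays self-contained:

from loguru import logger

def filter_codes_with_logging(allowed_codes, min_code_inclusion_frequency, feature_freqs):
    # Log the filtering criteria before applying them.
    logger.debug(
        "Filtering codes: allowed={}, min_code_inclusion_frequency={}",
        "all" if allowed_codes is None else len(allowed_codes),
        min_code_inclusion_frequency,
    )
    allowed = None if allowed_codes is None else set(allowed_codes)
    kept = sorted(
        code
        for code, freq in feature_freqs.items()
        if freq >= min_code_inclusion_frequency and (allowed is None or code in allowed)
    )
    # Log how aggressive the filter was.
    logger.debug("Kept {} of {} codes", len(kept), len(feature_freqs))
    return kept


print(filter_codes_with_logging(None, 2, {"A": 5, "B": 1}))  # ['A']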
def filter_parquet(fp, allowed_codes: list[str]):
    """Loads Parquet with Polars and filters to allowed codes.

    Args:
        fp: Path to the Meds cohort shard
        allowed_codes: List of codes to filter to.

    Expect:
    >>> from tempfile import NamedTemporaryFile
    >>> fp = NamedTemporaryFile()
    >>> pl.DataFrame({
    ...     "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
    ...     "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
    ...     "numerical_value": [1, None, 2, 2, None, 5, None, 3]
    ... }).write_parquet(fp.name)
    >>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
    shape: (6, 3)
    ┌──────┬────────────┬─────────────────┐
    │ code ┆ timestamp  ┆ numerical_value │
    │ ---  ┆ ---        ┆ ---             │
    │ str  ┆ str        ┆ i64             │
    ╞══════╪════════════╪═════════════════╡
    │ A    ┆ 2021-01-01 ┆ null            │
    │ A    ┆ 2021-01-01 ┆ null            │
    │ D    ┆ null       ┆ null            │
    │ D    ┆ null       ┆ null            │
    │ E    ┆ 2021-01-03 ┆ null            │
    │ E    ┆ 2021-01-04 ┆ 3               │
    └──────┴────────────┴─────────────────┘
    >>> fp.close()
    """
    df = pl.scan_parquet(fp)
    # Drop values that are rare
    # Drop Rare Static Codes
    static_present_feature_columns = [
        clear_code_aggregation_suffix(each) for each in get_feature_names("static/present", allowed_codes)
    ]
    static_first_feature_columns = [
        clear_code_aggregation_suffix(each) for each in get_feature_names("static/first", allowed_codes)
    ]
    code_feature_columns = [
        clear_code_aggregation_suffix(each) for each in get_feature_names("code/count", allowed_codes)
    ]
    value_feature_columns = [
        clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
    ]

    is_static_code = pl.col("timestamp").is_null()
    is_numeric_code = pl.col("numerical_value").is_not_null()
    rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
    rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
    rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
    rare_static_value = is_static_code & ~pl.col("code").is_in(static_first_feature_columns) & is_numeric_code

    # Remove rare numeric values by converting them to null
    df = df.with_columns(
        pl.when(rare_static_value | rare_ts_value)
        .then(None)
        .otherwise(pl.col("numerical_value"))
        .alias("numerical_value")
    )
    # Drop rows with rare codes
    df = df.filter(~(rare_static_code | rare_ts_code))
    return df
The function filter_parquet is crucial for filtering Parquet files based on allowed codes. Ensure that the filtering logic is optimized to handle large datasets efficiently.
- df = df.filter(~(rare_static_code | rare_ts_code))
+ # Optimize filtering logic for large datasets.
+ df = df.lazy().filter(~(rare_static_code | rare_ts_code)).collect()
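For context, a small self-contained sketch of the lazy filtering pattern filter_parquet already relies on via scan_parquet; an in-memory LazyFrame stands in for a real Parquet shard here:

import polars as pl

# Stand-in for pl.scan_parquet(fp): nothing is materialized yet.
lf = pl.LazyFrame({"code": ["A", "B", "A"], "numerical_value": [1, None, 2]})

keep = pl.col("code").is_in(["A"])
filtered = lf.filter(keep)   # still lazy; the predicate can be pushed down to the scan
print(filtered.collect())    # data is only materialized when the caller collects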
def get_rolling_window_indicies(index_df, window_size):
    """Get the indices for the rolling windows."""
    if window_size == "full":
        timedelta = pd.Timedelta(150 * 52, unit="W")  # just use 150 years as time delta
    else:
        timedelta = pd.Timedelta(window_size)
    return (
        index_df.with_row_index("index")
        .rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
        .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
        .select(pl.col("min_index", "max_index"))
        .collect()
    )
The function get_rolling_window_indicies effectively computes indices for rolling windows. Consider adding error handling for cases where the window_size parameter is not properly formatted.
+ if not isinstance(window_size, str):
+ raise ValueError("window_size must be a string representing a time delta.")
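A minimal sketch of the kind of validation being suggested, factored into a hypothetical helper; the 150-year stand-in for "full" mirrors the code above, and the example window strings are illustrative:

import pandas as pd

def resolve_window_timedelta(window_size: str) -> pd.Timedelta:
    # Validate the type first, then parse; "full" keeps the ~150-year sentinel.
    if not isinstance(window_size, str):
        raise ValueError("window_size must be a string such as '30d', '1h', or 'full'.")
    if window_size == "full":
        return pd.Timedelta(150 * 52, unit="W")  # effectively unbounded (~150 years)
    try:
        return pd.Timedelta(window_size)
    except ValueError as e:
        raise ValueError(f"Unparseable window_size: {window_size!r}") from e


print(resolve_window_timedelta("30d"))   # 30 days 00:00:00
print(resolve_window_timedelta("full"))  # 54600 days 00:00:00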
Merging Functional Pipeline with CLI to Main
Summary by CodeRabbit
New Features
Documentation
Chores: changes to pyproject.toml.
Refactor
Bug Fixes