Adding memory efficient updates #17
Conversation
… is slightly slower now, debugging this, but memory usage is stable
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main      #17      +/-   ##
==========================================
- Coverage   38.64%   37.80%   -0.84%
==========================================
  Files          14       14
  Lines         942      968      +26
==========================================
+ Hits          364      366       +2
- Misses        578      602      +24
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- src/MEDS_tabular_automl/generate_summarized_reps.py (3 hunks)
- src/MEDS_tabular_automl/generate_ts_features.py (2 hunks)
Additional context used
Ruff
src/MEDS_tabular_automl/generate_summarized_reps.py
6-6: Module level import not at top of file (E402)
7-7: Module level import not at top of file (E402)
9-9: Module level import not at top of file (E402)
10-10: Module level import not at top of file (E402)
11-16: Module level import not at top of file (E402)
Additional comments not posted (2)
src/MEDS_tabular_automl/generate_ts_features.py (1)
95-104: Ensure the sparse matrix generation handles different aggregation types correctly.

Verification successful.
The function `summarize_dynamic_measurements` handles different aggregation types (`CODE_AGGREGATIONS` and `VALUE_AGGREGATIONS`) and generates a sparse matrix accordingly. The implementation appears to correctly differentiate between these aggregation types and processes them using appropriate helper functions (`get_long_code_df` and `get_long_value_df`).

- The function checks whether `agg` is in `CODE_AGGREGATIONS` or `VALUE_AGGREGATIONS`.
- For `CODE_AGGREGATIONS`, it drops the `numerical_value` column and processes the data using `get_long_code_df`.
- For `VALUE_AGGREGATIONS`, it drops the ID columns and processes the data using `get_long_value_df`.
- The sparse matrix is then created using the processed data and the appropriate shape.

The sparse matrix generation seems to handle different aggregation types correctly based on the provided implementation.
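Sketched concretely, the branching described above might look like the following. This is a minimal illustration only: the constant values and the helper logic here are my assumptions, not the project's actual definitions.

```python
import numpy as np
from scipy.sparse import coo_array

# Aggregation families as referenced in the review; the member names
# below are illustrative assumptions, not the project's constants.
CODE_AGGREGATIONS = {"code/count"}
VALUE_AGGREGATIONS = {"value/sum", "value/min", "value/max"}


def build_sparse_for_agg(agg, values, rows, cols, shape):
    """Branch on the aggregation family and build a sparse matrix."""
    if agg in CODE_AGGREGATIONS:
        # Code aggregations drop numerical_value: presence alone is the data.
        data = np.ones(len(rows), dtype=np.int8)
    elif agg in VALUE_AGGREGATIONS:
        # Value aggregations keep the numerical values as the data.
        data = np.asarray(values, dtype=np.float64)
    else:
        raise ValueError(f"Unknown aggregation: {agg}")
    return coo_array((data, (rows, cols)), shape=shape)
```

The point is only the dispatch shape: one branch per aggregation family, with both branches converging on the same sparse-matrix constructor.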
Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify that all aggregation types are correctly handled in the sparse matrix generation.
# Test: Search for the function usage with different aggregation types. Expect: Correct handling for each type.
rg --type python $'summarize_dynamic_measurements'
```

Length of output: 80

Script:

```shell
#!/bin/bash
# Description: Verify that all aggregation types are correctly handled in the sparse matrix generation.
# Test: Search for the function definition and usage of `summarize_dynamic_measurements`.
rg 'def summarize_dynamic_measurements' -A 10
rg 'summarize_dynamic_measurements' -A 5
```

Length of output: 1770

Script:

```shell
#!/bin/bash
# Description: Extract the full implementation of `summarize_dynamic_measurements` and its usage context.

# Extract the full implementation of `summarize_dynamic_measurements`.
rg 'def summarize_dynamic_measurements' -A 50

# Extract the surrounding context of its usage.
rg 'return summarize_dynamic_measurements' -A 10 -B 10
```

Length of output: 5484
src/MEDS_tabular_automl/generate_summarized_reps.py (1)
19-27: Ensure the `sparse_aggregate` function correctly handles different aggregation methods.

Verification successful.
The `sparse_aggregate` function correctly handles different aggregation methods, including "sum", "min", "max", "sum_sqd", and "count". The `agg` variable covers all these methods, ensuring the function's correctness.

- `agg` is assigned values that include "sum", "min", "max", "sum_sqd", and "count".

Scripts executed
The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify that all aggregation methods are correctly implemented in the `sparse_aggregate` function.
# Test: Search for the function usage with different aggregation methods. Expect: Correct implementation for each method.
rg --type python $'sparse_aggregate'
```

Length of output: 66

Script:

```shell
#!/bin/bash
# Description: Verify that all aggregation methods are correctly implemented in the `sparse_aggregate` function.
# Test: Search for the function usage with different aggregation methods. Expect: Correct implementation for each method.
rg --type py 'sparse_aggregate'
```

Length of output: 378

Script:

```shell
#!/bin/bash
# Description: Extract and review the lines of code where `sparse_aggregate` is called to verify the usage of different aggregation methods.
rg --type py 'sparse_aggregate' -A 3 -B 3
```

Length of output: 1937

Script:

```shell
#!/bin/bash
# Description: Extract specific lines where `sparse_aggregate` is called to verify the usage of different aggregation methods.
rg --type py 'sparse_aggregate' -n
```

Length of output: 390

Script:

```shell
#!/bin/bash
# Description: Extract lines around 36 and 79 in `generate_summarized_reps.py` to verify the aggregation methods used in `sparse_aggregate` calls.
sed -n '33,39p;76,82p' src/MEDS_tabular_automl/generate_summarized_reps.py
```

Length of output: 714

Script:

```shell
#!/bin/bash
# Description: Search for the definitions or assignments of `value` and `agg` in the `generate_summarized_reps.py` file.

# Search for the definitions or assignments of `value`.
rg --type py 'value =' src/MEDS_tabular_automl/generate_summarized_reps.py

# Search for the definitions or assignments of `agg`.
rg --type py 'agg =' src/MEDS_tabular_automl/generate_summarized_reps.py
```

Length of output: 351
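For reference, a dispatcher of this shape can be sketched as follows. This is a hedged stand-in, not the project's actual `sparse_aggregate`; it only mirrors the five method names the verification found, using standard SciPy sparse operations.

```python
import numpy as np
from scipy.sparse import csr_array


def sparse_aggregate_sketch(matrix, agg):
    """Aggregate a sparse matrix over its rows for one named method."""
    if agg == "sum":
        return matrix.sum(axis=0)                   # dense 1-D ndarray
    elif agg == "min":
        return matrix.min(axis=0)                   # stays sparse
    elif agg == "max":
        return matrix.max(axis=0)                   # stays sparse
    elif agg == "sum_sqd":
        return matrix.multiply(matrix).sum(axis=0)  # column-wise sum of squares
    elif agg == "count":
        return matrix.getnnz(axis=0)                # nonzero counts per column
    else:
        raise ValueError(f"Unsupported aggregation: {agg}")
```

Note that in SciPy, `min`/`max` on a sparse array return a sparse result while `sum` and `getnnz` return dense arrays; that same type split is what the `sparse_assign` review comment later in this thread has to branch on.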
```python
x = df.with_columns(
    pl.col("code").cast(str).replace(column_to_int).cast(int).alias("code_index"),
    pl.lit(1).alias("count"),
).drop("code")
# sum up counts for same patient_id, timestamp, code_index
x = x.group_by("patient_id", "timestamp", "code_index").sum()
# combine codes and counts for same patient_id, timestamp
x = x.group_by("patient_id", "timestamp", maintain_order=True).agg(pl.col("code_index", "count"))

# repeat row_index for each code_index on that row (i.e. 1 row == 1 unique patient_id x timestamp)
rows = (
    x.with_row_index("row_index")
    .select(pl.col("row_index").repeat_by(pl.col("code_index").list.len()))
    .select(pl.col("row_index").explode())
    .collect()
    .to_numpy()
    .T
)[0]
cols = x.select(pl.col("code_index").explode()).collect().to_numpy().T[0]
data = x.select(pl.col("count").explode()).collect().to_numpy().T[0]
shape = (x.select(pl.len()).collect().item(), len(ts_columns))
return data, (rows, cols), shape
```
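The `(data, (rows, cols), shape)` triplet returned at the end of this hunk is the standard COO-constructor input. A small sketch with made-up values (the numbers below are illustrative, not taken from the PR):

```python
import numpy as np
from scipy.sparse import coo_array

# Hypothetical output of get_long_code_df: counts keyed by
# (patient_id x timestamp) row index and code_index column.
data = np.array([2, 1, 1])
rows = np.array([0, 0, 1])
cols = np.array([3, 7, 3])
shape = (2, 10)  # (unique patient/timestamp rows, len(ts_columns))

matrix = coo_array((data, (rows, cols)), shape=shape)
```

Returning the raw triplet instead of a constructed matrix lets the caller decide the final sparse format, which is part of how this PR keeps memory usage stable.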
Refactor the `get_long_code_df` function to improve clarity and efficiency.
```diff
- x = df.with_columns(
-     pl.col("code").cast(str).replace(column_to_int).cast(int).alias("code_index"),
-     pl.lit(1).alias("count"),
- ).drop("code")
- x = x.group_by("patient_id", "timestamp", "code_index").sum()
- x = x.group_by("patient_id", "timestamp", maintain_order=True).agg(pl.col("code_index", "count"))
- rows = (
-     x.with_row_index("row_index")
-     .select(pl.col("row_index").repeat_by(pl.col("code_index").list.len()))
-     .select(pl.col("row_index").explode())
-     .collect()
-     .to_numpy()
-     .T
- )[0]
- cols = x.select(pl.col("code_index").explode()).collect().to_numpy().T[0]
- data = x.select(pl.col("count").explode()).collect().to_numpy().T[0]
- shape = (x.select(pl.len()).collect().item(), len(ts_columns))
+ # Simplify the transformation and aggregation steps
+ x = (
+     df.with_columns(
+         pl.col("code").cast(str).replace(column_to_int).cast(int).alias("code_index"),
+         pl.lit(1).alias("count"),
+     )
+     .drop("code")
+     .group_by("patient_id", "timestamp", "code_index").agg(pl.sum("count").alias("count"))
+     .group_by("patient_id", "timestamp", maintain_order=True).agg(
+         pl.col("code_index").alias("code_indices"), pl.col("count").alias("counts")
+     )
+ )
+
+ # Flatten the lists to get the final rows, cols, and data arrays
+ rows = x.select(pl.col("code_indices").repeat_by(pl.col("counts")).explode()).collect().to_numpy().flatten()
+ cols = x.select(pl.col("code_indices").explode()).collect().to_numpy().flatten()
+ data = x.select(pl.col("counts").explode()).collect().to_numpy().flatten()
+ shape = (x.select(pl.len()).collect().item(), len(ts_columns))
```

This refactoring simplifies the data transformation and aggregation steps, making the code easier to understand and potentially more efficient by reducing the number of operations and intermediate data structures.
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def sparse_assign(row_index, index, data, row, col, sparse_matrix, value, dtype):
    agg_matrix = sparse_aggregate(sparse_matrix, value, dtype)
    if isinstance(agg_matrix, np.ndarray):
        nozero_ind = agg_matrix.nonzero()[0]
        len_data = len(nozero_ind)
        col[index : index + len_data] = nozero_ind
        data[index : index + len_data] = agg_matrix[nozero_ind]
        row[index : index + len_data] = row_index
    elif isinstance(agg_matrix, coo_array):
        len_data = len(agg_matrix.data)
        col[index : index + len_data] = agg_matrix.col
        data[index : index + len_data] = agg_matrix.data
        row[index : index + len_data] = row_index
    else:
        raise TypeError(f"Invalid matrix type {type(agg_matrix)}")
    return len_data
```
Refactor the `sparse_assign` function to improve clarity and efficiency.
```diff
- agg_matrix = sparse_aggregate(sparse_matrix, value, dtype)
- if isinstance(agg_matrix, np.ndarray):
-     nozero_ind = agg_matrix.nonzero()[0]
-     len_data = len(nozero_ind)
-     col[index : index + len_data] = nozero_ind
-     data[index : index + len_data] = agg_matrix[nozero_ind]
-     row[index : index + len_data] = row_index
- elif isinstance(agg_matrix, coo_array):
-     len_data = len(agg_matrix.data)
-     col[index : index + len_data] = agg_matrix.col
-     data[index : index + len_data] = agg_matrix.data
-     row[index : index + len_data] = row_index
- else:
-     raise TypeError(f"Invalid matrix type {type(agg_matrix)}")
+ # Simplify the handling of different matrix types
+ agg_matrix = sparse_aggregate(sparse_matrix, value, dtype)
+ len_data = len(agg_matrix.data)
+ col[index : index + len_data] = agg_matrix.col
+ data[index : index + len_data] = agg_matrix.data
+ row[index : index + len_data] = row_index
```

This refactoring simplifies the handling of different matrix types, making the code easier to understand and potentially more efficient by reducing the number of conditional branches.
Committable suggestion was skipped due to low confidence.
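The simplified version in the suggestion assumes `.data`/`.col` exist on every result, which is not true of a plain `np.ndarray`. If the two branches were ever unified, one possible approach (my assumption, not something proposed in the PR) is normalizing the dense case to COO first:

```python
import numpy as np
from scipy.sparse import coo_array


def to_coo(agg_matrix):
    """Normalize an aggregation result to a COO array so downstream
    code can always read .data, .row, and .col."""
    if isinstance(agg_matrix, np.ndarray):
        # A 1-D dense result becomes a single-row sparse array,
        # which keeps only its nonzero entries.
        return coo_array(agg_matrix.reshape(1, -1))
    return agg_matrix.tocoo()
```

With this helper, `sparse_assign` could collapse its two branches into the `coo_array` path without losing the nonzero-filtering behavior of the dense branch.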
@Oufattole is this going to be merged or should we close it unmerged? I'd rather not have it hanging around indefinitely.
Closing as this has gone a long time without update from the original code author. |