
Fixes the slowdowns and bugs caused by the prior improved compute practices, but requires a nested tensor package. #90

Merged: 30 commits, May 31, 2024

Conversation

@mmcdermott (Owner) commented Dec 20, 2023

This should fix #73 as well

@juancq (Contributor) commented Dec 27, 2023

@mmcdermott which polars and numpy versions were used for this branch?

Building the dataset works fine with this branch.

When running pretrain, I initially got a runtime error stemming from here:

```python
dense_tensors = {
    k: np.array(data_as_lists[k], dtype=tensor_types.get(k, np.float32)) for k in dense_keys
}
```

The error being:

```
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (55877,) + inhomogeneous part.
```

When I change the numpy array dtype to object, it just hangs (it runs for an incredibly long time; I haven't been patient enough to wait it out). I don't know whether safetensors or nested_ragged_tensors is to blame (I assume the latter). This occurs when trying to run pretrain, sampling about 90,000 subjects from a full dataset of about 600,000.
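For context, this is the error NumPy raises whenever the nested lists it is asked to pack have unequal lengths; a minimal reproduction (with made-up data) is:

```python
import numpy as np

# Rows of unequal length form a ragged structure; NumPy cannot pack them
# into a dense float32 array and raises the same ValueError as above.
rows = [[1.0, 2.0], [3.0]]
np.array(rows, dtype=np.float32)
# ValueError: setting an array element with a sequence. The requested array
# has an inhomogeneous shape after 1 dimensions. ...
```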

@mmcdermott (Owner, Author) commented

@juancq, definitely don't feel obliged to wait it out when setting it to an object; that will be very slow and defeats the purpose of this change. One question, though: in your data, do you have reason to believe that your subjects will have varying numbers of observations of static measurements per row? I believe that what is going on here is that something being treated as a dense tensor is in reality a ragged one.

@juancq (Contributor) commented Jan 16, 2024

> @juancq, definitely don't feel obliged to wait it out when setting it to an object; that will be very slow and defeats the purpose of this change. One question, though: in your data, do you have reason to believe that your subjects will have varying numbers of observations of static measurements per row? I believe that what is going on here is that something being treated as a dense tensor is in reality a ragged one.

I don't follow; what would be an example?

This branch also has parts that break with certain polars versions, which makes it hard to test, because I don't know which exact polars version to use.
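For what it's worth, the raggedness in question would look something like this hypothetical sketch (subject IDs and values are made up):

```python
# If subjects have different numbers of static measurements, the "dense"
# static tensor is really ragged, and np.array(...) on it fails as above.
static_indices = {
    "subject_A": [4, 17],      # two static measurements
    "subject_B": [4, 17, 23],  # three static measurements
}
```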

@juancq (Contributor) commented Feb 8, 2024

@mmcdermott I have sorted out the polars issues and have gotten further in testing this.

My pretraining now hangs here:

```python
sparse_dynamic_tensors = full_sparse_dynamic_tensors[subject_idx]
```

When I kill the script, the stack trace is something along the following lines (in nested_ragged_tensors):

  • `__getitem__`: https://github.com/mmcdermott/nested_ragged_tensors/blob/b3a01a661e73e9c55541430dbb985c51d054fa4d/src/nested_ragged_tensors/ragged_numpy.py#L398
  • `vstack`: https://github.com/mmcdermott/nested_ragged_tensors/blob/b3a01a661e73e9c55541430dbb985c51d054fa4d/src/nested_ragged_tensors/ragged_numpy.py#L655
  • `concatenate`: https://github.com/mmcdermott/nested_ragged_tensors/blob/b3a01a661e73e9c55541430dbb985c51d054fa4d/src/nested_ragged_tensors/ragged_numpy.py#L754

If this is too cryptic, let me know and I'll rephrase or post an issue on the nested_ragged_tensors repo.

@mmcdermott (Owner, Author) commented

Thank you @juancq -- I've been travelling quite a bit and otherwise occupied for the last month and a half, but I'm trying to push a new major version of ESGPT that addresses these issues and the other memory issues. I assume your last comment still reflects the state of things with this change for you?

@mmcdermott (Owner, Author) commented

@coderabbitai review

coderabbitai bot commented Mar 22, 2024

Walkthrough

The recent changes in the EventStream project involve enhancing data handling, caching mechanisms, and error management across multiple files. These updates include refining file extension checks, improving exception handling for data conversion, restructuring caching mechanisms for efficiency, and aligning test cases with the updated data structures and logic.

Changes

| File | Change Summary |
| --- | --- |
| EventStream/data/config.py | Modified `tensorized_cached_files` to use a dictionary comprehension with a different file extension check. |
| EventStream/data/dataset_polars.py | Added `defaultdict` import, improved exception handling in `_filter_col_inclusion`, and updated `build_DL_cached_representation` for new aggregation of `time_delta`. |
| EventStream/data/pytorch_dataset.py | Restructured caching mechanisms, introduced new tensor structures for caching, revamped logic for handling dense and ragged tensors, and updated the caching process for improved efficiency. |
| EventStream/baseline/FT_task_baseline.py | Renamed `ConstructorPytorchDataset` to `PytorchDataset`, affecting task normalization. |
| EventStream/data/dataset_base.py | Added imports for polars and `JointNestedRaggedTensorDict`, modified caching logic for DL representations, handled sharding, and cached NRT representations using Polars. |
| tests/data/test_pytorch_dataset.py | Removed unnecessary imports, updated references and data structures, refactored methods for handling temporary directories, and revised test cases to align with the new data structures and logic. |
| EventStream/data/types.py | Changed representation of null values from "null" to "nul" in `convert_to_DL_DF`. |

Assessment against linked issues

| Objective (Issue #73) | Addressed | Explanation |
| --- | --- | --- |
| Memory consumption increase with DataLoader and num_workers > 0 | | The changes include significant restructuring of caching mechanisms and data handling, which might indirectly address memory consumption issues. However, no direct fix for the issue is evident. |


coderabbitai bot left a review comment

Review Status

Actionable comments generated: 2

Configuration used: CodeRabbit UI

Commits: files that changed from the base of the PR, between 9a1d72d and fa33387.
Files ignored due to path filters (1)
  • pyproject.toml is excluded by: !**/*.toml
Files selected for processing (3)
  • EventStream/data/config.py (1 hunks)
  • EventStream/data/dataset_polars.py (2 hunks)
  • EventStream/data/pytorch_dataset.py (7 hunks)
Additional comments: 11
EventStream/data/pytorch_dataset.py (9)
  • 9-21: Note: this review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap (original lines 1-17).

The imports are appropriate and align with the functionality implemented in this file. The use of nested_ragged_tensors is particularly noteworthy, as it directly relates to the PR's objective of handling ragged tensors more efficiently.

  • 77-77: Ensure that len(self.sparse_static_tensors) correctly reflects the dataset's length. If sparse_static_tensors is expected to be a complete representation of the dataset, this is fine. Otherwise, consider a more robust way to determine the dataset's length.
  • 133-223: The caching logic in _cache_subset and the subsequent saving of tensors and data stats are well-implemented. However, consider the following improvements for readability and efficiency:
  • Extracting repeated code blocks into helper functions.
  • Using more descriptive variable names for clarity.
  • Adding comments to complex logic sections for better maintainability.
  • 244-289: The construction and saving of dense and ragged tensors in _cache_full_data are crucial for the dataset's performance. Ensure that:
  • The data types and shapes of tensors are correctly handled, especially when converting lists to numpy arrays and tensors.
  • The use of JointNestedRaggedTensorDict for ragged tensors aligns with the expected data structure and performance requirements.
  • 292-305: The method fetch_tensors efficiently loads the cached tensors and applies necessary configurations. However, consider the following:
  • Ensure that the deletion of tensors (start_time, start_idx, end_idx, subject_id) based on configuration flags does not inadvertently affect downstream processing.
  • Validate that all necessary tensors are loaded and correctly handled for different configurations.
  • 353-357: The __getitem__ method's implementation appears correct. Ensure that the slicing and indexing logic correctly handles edge cases, especially for datasets with varying sequence lengths or missing data.
  • 373-393: The collate method's handling of ragged tensors and conversion to dense tensors is critical. Verify that:
  • The conversion logic (torch.from_numpy(v.astype(int)).long(), torch.from_numpy(v.astype(float)).float()) correctly handles all data types present in the dataset.
  • The handling of masks (event_mask, dynamic_values_mask) correctly identifies valid and missing data points.
  • 401-408: In the collate method, ensure that the conversion of tensors to the appropriate types (float, long) and the handling of nan values (torch.nan_to_num) are correctly applied to all relevant tensors. This is crucial for maintaining data integrity and preventing errors in downstream processing (a minimal sketch of these conversions follows this list).
  • 244-289: The __getitem__ and collate methods in ConstructorPytorchDataset are crucial for data loading and preprocessing. Ensure that:
  • The logic for handling static and dynamic data, including padding and type conversion, is correctly implemented.
  • The handling of task labels (multi_class_classification, binary_classification, regression) is robust and correctly converts labels to the appropriate tensor types.
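A minimal sketch of the collate-time conversions flagged in the bullets above, assuming padded ragged fields arrive as NumPy arrays with NaN padding (field names are illustrative):

```python
import numpy as np
import torch

# Hypothetical padded batch: 2 subjects x 2 events of dynamic values.
dynamic_values = np.array([[0.5, np.nan], [1.25, np.nan]])
dynamic_indices = np.array([[3, 0], [7, 0]])

vals = torch.from_numpy(dynamic_values.astype(float)).float()
idx = torch.from_numpy(dynamic_indices.astype(int)).long()

# The mask records which positions held real (non-padding) values...
dynamic_values_mask = ~torch.isnan(vals)
# ...and NaNs are then zeroed so downstream arithmetic stays finite.
vals = torch.nan_to_num(vals, nan=0.0)
```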
EventStream/data/config.py (1)
  • 1040-1040: The modification to use a dictionary comprehension for file retrieval seems efficient. However, it would be beneficial to add a comment specifying the expected file extension for clarity and maintainability.
EventStream/data/dataset_polars.py (1)
  • 13-13: Adding defaultdict from the collections module is a good choice for handling dynamic and potentially unknown keys in dictionaries. This import is used effectively in the exception handling within the _filter_col_inclusion method.

Comment on lines +736 to +753

```python
try:
    incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
except TypeError as e:
    incl_targets_by_type = defaultdict(list)
    for t in incl_targets:
        incl_targets_by_type[str(type(t))].append(t)

    by_type_summ = []
    for tp, vals in incl_targets_by_type.items():
        by_type_summ.append(
            f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..."
        )

    by_type_summ = "\n".join(by_type_summ)

    raise ValueError(
        f"Failed to convert incl_targets to {df.schema[col]}:\n{by_type_summ}"
    ) from e
```

The changes made to the _filter_col_inclusion method enhance its robustness by adding exception handling for type conversion issues. This is particularly useful when dealing with dynamic data types that may not always conform to expected formats. The detailed error message constructed from the incl_targets_by_type dictionary provides clear insight into the nature of the conversion failure, which can significantly aid in debugging.

However, consider adding a comment explaining the purpose of creating a summary of conversion failures by type and values. This will help future maintainers understand the rationale behind these changes more quickly.

```diff
+ # Handle type conversion exceptions by summarizing conversion failures by type and values
  try:
      incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
```

Comment on lines 244 to 289
```diff
+        data_as_lists = defaultdict(list)
+        for ep in tqdm(range(self.config.cache_for_epochs), total=self.config.cache_for_epochs, leave=False):
             for it in tqdm(constructor_pyd, total=len(constructor_pyd)):
-                items.append(it)
-
-        logger.info("Collating data into dense tensors to cache.")
-        global_batch = constructor_pyd.collate(items, do_convert_float_nans=False)
-
-        tensors_to_cache = []
-        seen_keys = set()
-        for k, T in global_batch.items():
-            if k.endswith("_mask") and k != "event_mask":
-                continue
-            if T is None:
-                continue
-            if isinstance(T, torch.Tensor):
-                if k in seen_keys:
-                    raise KeyError(f"Duplicate tensor save key {k}!")
-                tensors_to_cache.append((k, T))
-                seen_keys.add(k)
-            elif isinstance(T, dict):
-                for kk, TT in T.items():
-                    if TT is None:
-                        continue
-                    elif not isinstance(TT, torch.Tensor):
-                        raise TypeError(f"Unrecognized tensor type {type(TT)} @ {k}/{kk}!")
-
-                    if kk in seen_keys:
-                        raise KeyError(f"Duplicate tensor save key {kk}!")
-                    tensors_to_cache.append((kk, TT))
-                    seen_keys.add(kk)
-            else:
-                raise TypeError(f"Unrecognized tensor type {type(T)} @ {k}!")
+                for k, val in it.items():
+                    data_as_lists[k].append(val)
+
+        logger.info("Constructing tensors to cache.")
+        logger.info(f"Dataset keys: {data_as_lists.keys()}")
+
+        sparse_1_keys = ["time_delta", "dynamic_indices", "dynamic_values", "dynamic_measurement_indices"]
+        sparse_2_keys = ["static_indices", "static_measurement_indices"]
+        dense_keys = [k for k in data_as_lists.keys() if k not in sparse_1_keys + sparse_2_keys]
+        tensor_types = {
+            "subject_id": np.int64,
+            "static_indices": np.int64,
+            "static_measurement_indices": np.int64,
+        }

-        for k, T in tqdm(tensors_to_cache, leave=False, desc="Caching..."):
-            fp = self._full_data_config.tensorized_cached_dir / self.split / f"{k}.pt"
-            fp.parent.mkdir(exist_ok=True, parents=True)
-            st = datetime.now()
-            logger.info(f"Caching tensor {k} of shape {T.shape} to {fp}...")
-            torch.save(T, fp)
-            logger.info(f"Done in {datetime.now() - st}")
+        # Dense tensors
+        dense_tensors = {}
+        for k in dense_keys:
+            logger.info(f"Collating {k} as a dense tensor")
+            dense_tensors[k] = np.array(data_as_lists[k], dtype=tensor_types.get(k, np.float32))
+
+        fp = self._full_data_config.tensorized_cached_dir / self.split / "dense.npz"
+        logger.info(f"Saving dense tensors to {fp}")
+        save_file(dense_tensors, fp)
+
+        # Ragged tensors
+        logger.info(f"Constructing ragged tensors across {sparse_1_keys}")
+        sparse_tensors_dict = {k: data_as_lists[k] for k in sparse_1_keys}
+        sparse_tensors_dict["dynamic_values"] = [
+            [[v if v is not None else float("nan") for v in measurements] for measurements in events]
+            for events in sparse_tensors_dict["dynamic_values"]
+        ]
+        sparse_tensors = JointNestedRaggedTensorDict(sparse_tensors_dict)
+        fp = self._full_data_config.tensorized_cached_dir / self.split / "sparse_dynamic.npz"
+        logger.info(f"Saving sparse tensors to {fp}")
+        sparse_tensors.save(fp)
+
+        logger.info(f"Constructing ragged tensors across {sparse_2_keys}")
+        sparse_tensors_dict = {k: data_as_lists[k] for k in sparse_2_keys}
+        sparse_tensors = JointNestedRaggedTensorDict(sparse_tensors_dict)
+        fp = self._full_data_config.tensorized_cached_dir / self.split / "sparse_static.npz"
+        logger.info(f"Saving sparse tensors to {fp}")
+        sparse_tensors.save(fp)
```

The method _build_task_cached_df in ConstructorPytorchDataset is complex and critical for handling task-specific data. Consider the following improvements:

  • Simplify the logic by breaking down the method into smaller, more manageable functions.
  • Add detailed comments explaining the purpose and logic of each step, especially for complex operations.
  • Ensure that the handling of time-dependent columns and the construction of task-specific dataframes are correctly implemented and optimized for performance.
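For readers unfamiliar with the nested tensor package the diff above leans on, here is a minimal, hedged sketch of the pattern: a JointNestedRaggedTensorDict is built from per-subject nested lists, saved to disk, and sliced per subject. The keys and values are illustrative only, and behavior may vary across nested_ragged_tensors versions.

```python
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

# Two subjects with ragged structure (made-up data): event counts differ
# per subject, and measurement counts differ per event.
tensors = JointNestedRaggedTensorDict({
    "time_delta": [[1.0, 2.0, 0.5], [3.0]],
    "dynamic_indices": [[[1, 2], [3], [4, 5]], [[6]]],
})

tensors.save("sparse_dynamic.npz")  # mirrors the save calls in the diff above
subject_0 = tensors[0]              # per-subject slicing, as used in __getitem__
```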

@juancq (Contributor) commented Mar 25, 2024

> Thank you @juancq -- I've been travelling quite a bit and otherwise occupied for the last month and a half, but I'm trying to push a new major version of ESGPT that addresses these issues and the other memory issues. I assume your last comment still reflects the state of things with this change for you?

Yes, the state of things on my end is the same as of the time I wrote the last comment.

coderabbitai bot left a review comment

Actionable comments posted: 1

Actionable comments outside the diff hunks (15)
EventStream/data/dataset_polars.py (1)

Line range hint 203-203: Replace type comparison with isinstance() for better practice and compatibility.

- if type(qq) is Path:
+ if isinstance(qq, Path):
EventStream/data/config.py (14)

Line range hint 117-117: Replace type comparison with isinstance() for better practice and readability.

- if type(self.static) is dict:
+ if isinstance(self.static, dict):

Line range hint 125-125: Replace type comparison with isinstance() for better practice and readability.

- if type(v) is dict:
+ if isinstance(v, dict):

Line range hint 407-407: Replace type comparison with isinstance() for better practice and readability.

- if type(self.data_schema) is not list and self.data_schema is not None:
+ if not isinstance(self.data_schema, list) and self.data_schema is not None:

Line range hint 409-409: Replace type comparison with isinstance() for better practice and readability.

- if type(self.start_data_schema) is not list and self.start_data_schema is not None:
+ if not isinstance(self.start_data_schema, list) and self.start_data_schema is not None:

Line range hint 411-411: Replace type comparison with isinstance() for better practice and readability.

- if type(self.end_data_schema) is not list and self.end_data_schema is not None:
+ if not isinstance(self.end_data_schema, list) and self.end_data_schema is not None:

Line range hint 625-625: Replace type comparison with isinstance() for better practice and readability.

- if type(self.min_seq_len) is not int or self.min_seq_len < 0:
+ if not isinstance(self.min_seq_len, int) or self.min_seq_len < 0:

Line range hint 631-631: Replace type comparison with isinstance() for better practice and readability.

- if type(self.max_seq_len) is not int or self.max_seq_len < self.min_seq_len:
+ if not isinstance(self.max_seq_len, int) or self.max_seq_len < self.min_seq_len:

Line range hint 649-649: Replace type comparison with isinstance() for better practice and readability.

- if type(self.save_dir) is str and self.save_dir != omegaconf.MISSING:
+ if isinstance(self.save_dir, str) and self.save_dir != omegaconf.MISSING:

Line range hint 911-911: Replace type comparison with isinstance() for better practice and readability.

- if type(val) is not dict:
+ if not isinstance(val, dict):

Line range hint 913-913: Replace type comparison with isinstance() for better practice and readability.

- if type(val) is not dict:
+ if not isinstance(val, dict):

Line range hint 919-919: Replace type comparison with isinstance() for better practice and readability.

- if type(val) is not dict:
+ if not isinstance(val, dict):

Line range hint 1922-1922: Replace type comparison with isinstance() for better practice and readability.

- if type(self.save_dir) is str:
+ if isinstance(self.save_dir, str):

Line range hint 1931-1931: Replace type comparison with isinstance() for better practice and readability.

- if type(val) is not dict:
+ if not isinstance(val, dict):

Line range hint 1958-1958: Replace type comparison with isinstance() for better practice and readability.

- if type(val) is not dict:
+ if not isinstance(val, dict):
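The rationale behind this repeated suggestion, illustrated with a tiny standard-library example: exact type() comparisons miss subclasses, while isinstance() respects inheritance.

```python
x = True  # bool is a subclass of int

print(type(x) is int)      # False: exact-type comparison ignores subclassing
print(isinstance(x, int))  # True: isinstance matches subclasses as well
```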

Repository owner deleted a comment from coderabbitai bot Apr 22, 2024
@mmcdermott (Owner, Author) commented

This is blocked by #104

coderabbitai bot left a review comment

Actionable comments posted: 4

Out of diff range and nitpick comments (4)
EventStream/baseline/FT_task_baseline.py (3)

Tip: Codebase Verification

Several files in the codebase still use the old function signature.

  • No direct or indirect tests for the load_flat_rep function were found in the codebase.
  • Ensure that comprehensive tests are added to cover all branches and exception handling for load_flat_rep.
Analysis chain

Line range hint 35-241: Ensure comprehensive testing of load_flat_rep to cover all branches and exception handling.

Scripts Executed

The following scripts were executed for the analysis:

```bash
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.

# Test: Search for test cases covering all branches in `load_flat_rep`.
rg --type python $'def test_load_flat_rep'
```

Length of output: 74

```bash
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.

# Test: Search for test cases covering all branches in `load_flat_rep`.
rg --type py 'def test_load_flat_rep'
```

Length of output: 39

```bash
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.

# Test: Search for references to `load_flat_rep` in test files.
rg 'load_flat_rep' --type py --glob '*test*.py'
```

Length of output: 47

```bash
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.

# Test: Search for references to `FT_task_baseline.py` and `load_flat_rep` within the codebase.
rg 'FT_task_baseline' --type py
rg 'load_flat_rep' --type py
```

Length of output: 305

```bash
#!/bin/bash
# Description: Display the contents of `scripts/sklearn_baseline.py` to check for testing logic.

# Display the contents of the file.
cat scripts/sklearn_baseline.py
```

Length of output: 798


Line range hint 241-241: Use isinstance() for type checking instead of comparing types directly.

- if type(window_sizes) is not list:
+ if not isinstance(window_sizes, list):
EventStream/data/dataset_polars.py (1)

Line range hint 204-204: Use isinstance() instead of direct type comparison.

- if type(qq) is Path:
+ if isinstance(qq, Path):
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR, between 1dfb6ef and 79d41a4.
Files ignored due to path filters (1)
  • pyproject.toml is excluded by !**/*.toml
Files selected for processing (5)
  • EventStream/baseline/FT_task_baseline.py (2 hunks)
  • EventStream/data/dataset_base.py (2 hunks)
  • EventStream/data/dataset_polars.py (3 hunks)
  • EventStream/data/pytorch_dataset.py (8 hunks)
  • tests/data/test_pytorch_dataset.py (7 hunks)
Files not summarized due to errors (1)
  • EventStream/data/pytorch_dataset.py: Error: Message exceeds token limit
Additional Context Used
Ruff (2)
EventStream/baseline/FT_task_baseline.py (1)

241-241: Do not compare types, use isinstance()

EventStream/data/dataset_polars.py (1)

204-204: Do not compare types, use isinstance()

Additional comments not posted (16)
tests/data/test_pytorch_dataset.py (6)

13-13: Import of JointNestedRaggedTensorDict added.

This import is necessary for handling nested ragged tensor data structures, which aligns with the PR's objectives to handle complex tensor operations.


56-58: Calculation of event time deltas.

The logic for calculating time deltas between events is correct and efficiently uses list comprehensions. This is crucial for the time series data manipulations required in the dataset.

Also applies to: 66-68, 77-79
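A minimal sketch of the delta computation being described (timestamps are made up; the real tests operate on event datetimes):

```python
# Gap from each event to the next within one subject; the final event
# has no successor, so it contributes no delta.
times = [0.0, 2.5, 7.0]
time_deltas = [nxt - cur for cur, nxt in zip(times, times[1:])]
print(time_deltas)  # [2.5, 4.5]
```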


85-86: Introduction of time and time_delta fields in DL_REP_DF.

These fields are essential for representing time series data in a structured format, which is beneficial for downstream processing and analysis.


315-361: Setup and teardown methods in TestPytorchDataset.

The setup and teardown methods are well-structured, ensuring that temporary directories and files are correctly managed. This is crucial for maintaining test isolation and cleanliness.


Line range hint 363-424: Implementation of get_pyd method.

The method is implemented correctly to configure and retrieve a PytorchDataset instance based on the provided configuration. This flexibility is important for testing different configurations.


421-424: Assertion logic in test_normalize_task.

The assertion checks that the normalized values match the expected results, which is crucial for verifying the correctness of the normalization logic.

EventStream/data/pytorch_dataset.py (4)

9-15: Import of JointNestedRaggedTensorDict and related constants.

These imports are necessary for handling complex data structures involving nested and ragged tensors, aligning with the PR's objectives to enhance data handling capabilities.


Line range hint 61-160: Comprehensive restructuring of PytorchDataset class.

The restructuring introduces more robust handling of configurations, vocabulary, and data shards. This enhancement is crucial for managing complex datasets and configurations in a scalable manner.


Line range hint 172-398: Methods for reading and setting configurations and data.

These methods are well-implemented to ensure robust data management and error handling, which is essential for maintaining data integrity and facilitating efficient data processing.


Line range hint 426-611: Implementation of __getitem__ and collate methods.

These methods are crucial for fetching and collating data in a format suitable for PyTorch models. The implementation handles various data types and structures effectively, ensuring compatibility with different model architectures.

EventStream/baseline/FT_task_baseline.py (1)

35-35: Update import to reflect the new class name.

EventStream/data/dataset_base.py (3)

21-21: Ensure the new import polars as pl is utilized appropriately in the code.


24-24: Ensure the new import JointNestedRaggedTensorDict from nested_ragged_tensors.ragged_numpy is utilized appropriately in the code.


1367-1367: The creation of NRT_dir is not followed by any operations within this directory in the provided code. Verify if this is intended or if additional code is missing.

Verification successful

The creation of NRT_dir is followed by operations within this directory, such as creating directories and saving files. Therefore, the code appears to be functioning as intended.

  • NRT_fp.parent.mkdir(exist_ok=True, parents=True)
  • jnrt_dict.save(NRT_fp)
Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Check for usage of NRT_dir in the repository.
rg --type py 'NRT_dir'

Length of output: 189



Script:

#!/bin/bash
# Description: Check for usage of NRT_fp in the repository.
rg --type py 'NRT_fp'

Length of output: 821

EventStream/data/dataset_polars.py (2)

726-743: Enhance clarity by adding a comment explaining the purpose of summarizing conversion failures by type and values.


1410-1420: The implementation of time_delta calculation in build_DL_cached_representation looks correct and efficient.
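As a hedged illustration of what such a per-subject time_delta aggregation can look like in polars (column names are illustrative and may not match the actual implementation):

```python
import polars as pl

df = pl.DataFrame({
    "subject_id": [1, 1, 1, 2, 2],
    "time": [0.0, 2.5, 7.0, 1.0, 4.0],
})

# Delta to the next event within each subject; the last event per subject
# gets a null delta.
out = df.with_columns(
    (pl.col("time").shift(-1) - pl.col("time")).over("subject_id").alias("time_delta")
)
```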

```diff
@@ -1358,28 +1360,67 @@ def cache_deep_learning_representation(
         """

         logger.info("Caching DL representations")
         if subjects_per_output_file is None:
```

Consider specifying a default value for subjects_per_output_file to avoid potential issues with undefined behavior.

- if subjects_per_output_file is None:
+ if subjects_per_output_file is None:
+     subjects_per_output_file = 100  # Default value, adjust as necessary


```python
split_cached_df = self._filter_col_inclusion(cached_df, {"subject_id": subjects})
self._write_df(split_cached_df, fp, do_overwrite=do_overwrite)
shards_fp.write_text(json.dumps(shards))
```

Consider using json.dump directly with a file handle instead of write_text for writing JSON data. This is more idiomatic and efficient.

- shards_fp.write_text(json.dumps(shards))
+ with open(shards_fp, 'w') as f:
+     json.dump(shards, f)


logger.info(f"Skipping {NRT_fp} as it already exists.")
else:
logger.info(f"Caching NRT for {shard_key} to {NRT_fp}")
# TODO(mmd): This breaks the API isolation a bit, as we assume polars here. But that's fine.

The comment about API isolation breakage should be addressed or clarified to ensure the design is robust and maintainable.

Consider revising the architecture to maintain API isolation or provide a detailed justification for the current approach.

```diff
@@ -660,7 +660,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
     task_dfs = add_tasks_from(ESD.config.save_dir / "task_dfs")
     task_df = task_dfs[cfg.task_df_name]

-    task_type, normalized_label = ConstructorPytorchDataset.normalize_task(
+    task_type, normalized_label = PytorchDataset.normalize_task(
```

Consider adding more detailed logging for better traceability during model training.

+ logger.debug(f"Task type: {task_type}, Label: {normalized_label}")


@mmcdermott (Owner, Author) commented

@juancq I know this update is long overdue, but the recent pushes use the nested ragged tensor code in a new way that should dramatically reduce the CPU memory burden during model training. You'll need to re-build the base dataset object first to produce the right cached files (though I may write a conversion script to ease that cost, if that would be helpful), but once that is done this should, with minimal to no impact on throughput, pull patient data from disk only as needed rather than loading it all into memory. Note this requires updating to the latest version of nested_ragged_tensors as well. If you try it and find it useful, or find issues with it, I'd be very appreciative of and curious about your findings!
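A rough sketch of the on-demand access pattern described here, assuming the cached NRT file is sliced per subject at __getitem__ time (the helper is hypothetical; check nested_ragged_tensors for the exact API):

```python
from pathlib import Path

from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

def load_subject(nrt_fp: Path, subject_idx: int):
    """Read only one subject's slice from the cached file, rather than
    materializing the whole split in memory. `load_slice` is assumed here;
    if your version lacks it, index an already-loaded dict instead."""
    return JointNestedRaggedTensorDict.load_slice(nrt_fp, subject_idx)
```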

coderabbitai bot left a review comment

Actionable comments posted: 0

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR, between 79d41a4 and 6c20b5d.
Files ignored due to path filters (1)
  • pyproject.toml is excluded by !**/*.toml
Files selected for processing (1)
  • EventStream/data/types.py (1 hunks)
Files skipped from review due to trivial changes (1)
  • EventStream/data/types.py

coderabbitai bot left a review comment

Actionable comments posted: 0

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR, between 6c20b5d and 330f6ac.
Files selected for processing (1)
  • EventStream/data/pytorch_dataset.py (8 hunks)
Files not summarized due to errors (1)
  • EventStream/data/pytorch_dataset.py: Error: Message exceeds token limit
Additional comments not posted (12)
EventStream/data/pytorch_dataset.py (12)

9-15: Imports from nested_ragged_tensors.ragged_numpy are added to handle nested and ragged tensor data.


61-62: Introduction of the PytorchDataset class with detailed documentation.


132-156: The constructor (__init__) has been expanded with several new logging statements and method calls to set up the dataset based on the configuration. This includes reading vocabulary, shards, and patient descriptors, and applying filters based on sequence length and training subset size.


169-179: Methods read_vocabulary and read_shards are introduced to load configuration and data shard information from disk. These methods are crucial for setting up the dataset and ensuring that data is loaded efficiently.


185-215: The read_patient_descriptors method has been significantly expanded to handle loading and processing of patient descriptors. This includes reading parquet files and handling potential duplicate subject entries, which could lead to data inconsistencies.


226-242: Handling of task-specific dataframes and information has been added. This includes reading and validating task constraints and information from disk, and handling discrepancies in task information. This is critical for ensuring that the dataset is correctly set up for specific tasks.


247-287: Complex data manipulation and joining operations are performed to align task-specific data with the main dataset. This includes handling of start and end indices for tasks, which is crucial for correct data slicing during model training.


296-318: The get_task_info method is introduced to extract and normalize task information from the dataframe. This includes determining the type of task (classification, regression) and setting up vocabulary for tasks. This method is essential for configuring the dataset to handle different types of tasks correctly.


322-342: Methods filter_to_min_seq_len and filter_to_subset have been added to apply filters on the dataset based on minimum sequence length and subset size. These methods are important for ensuring that the dataset meets specific training requirements.


Line range hint 365-396: The method set_inter_event_time_stats calculates statistics on inter-event times and handles cases where these times are invalid (<=0). This is crucial for ensuring data quality and consistency.


Line range hint 424-509: The __getitem__ method and its helper _seeded_getitem have been significantly expanded to handle dynamic and static data indices, and to apply subsequence sampling strategies based on the configuration. This is critical for preparing data for model input.


Line range hint 511-609: The collate method and its helper __dynamic_only_collate have been updated to handle the collation of batch data, including handling of ragged tensor data and padding of static data fields. This is essential for preparing batches of data for model training.

@juancq (Contributor) commented May 20, 2024

@mmcdermott thanks for all the hard work. I tested this branch on my dataset. The previous bugs are gone. I am now seeing about a 7% runtime improvement per epoch and about 30% lower memory usage.

@mmcdermott (Owner, Author) commented

Fantastic! Thanks so much, @juancq. I'll do some final testing just to make sure there are no issues and plan to merge this branch soon. Glad this has resolved your issues and brought other improvements besides.

@mmcdermott merged commit 0350d7c into dev on May 31, 2024 (2 checks passed).