Fixes the slowdowns and bugs caused by the prior improved compute practices, but requires a nested tensor package. #90

Merged
merged 30 commits on May 31, 2024
Changes from 27 of 30 commits

Commits
007fd49
Updated error message
mmcdermott Dec 18, 2023
ca65a71
Attempting to try using ragged tensors from https://github.com/mmcder…
mmcdermott Dec 18, 2023
5947a91
Fixed a few small errors
mmcdermott Dec 19, 2023
774a9bc
Fixed small typo; may or may not still be working.
mmcdermott Dec 19, 2023
d0ee233
Fixed small typo
mmcdermott Dec 19, 2023
5d3b403
Improved collate fn.
mmcdermott Dec 19, 2023
621a939
First working version
mmcdermott Dec 19, 2023
9f5013b
A version using numpy instead of torch for collation and such.
mmcdermott Dec 19, 2023
c05428b
Further optimizations
mmcdermott Dec 20, 2023
e346808
Added _cache_subset
pargaw Dec 22, 2023
b22e6ce
Fixed some logger typos
pargaw Dec 22, 2023
a023b62
temporarily set max recursion limit to account for large subset sizes
pargaw Dec 22, 2023
29c3b9f
Removed unused comment
pargaw Dec 22, 2023
8c73fd3
Make logging more detailed.
mmcdermott Jan 4, 2024
78a1aac
Merge branch 'using_ragged_tensors' of github.com:mmcdermott/EventStr…
mmcdermott Jan 4, 2024
4245128
Fixed some small typos and made cached subsets actually be re-loaded …
mmcdermott Jan 4, 2024
35bacf2
Cache data_parameters for subset sizes
pargaw Jan 4, 2024
6095101
Fixed lint errors.
mmcdermott Jan 9, 2024
fa33387
Added nested_ragged_tensors as a dependency.
mmcdermott Jan 22, 2024
475a7f0
Merged.
mmcdermott Apr 22, 2024
fd20002
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 22, 2024
1dfb6ef
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 22, 2024
2873344
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 24, 2024
476661a
Updated to write NRT files and use NRT files in a smarter, much less …
mmcdermott May 16, 2024
cd9a661
Some improvements to the test code; pytorch tests are currently failing.
mmcdermott May 17, 2024
6a370bf
Removing out-dated tests with new pytorch dataset format.
mmcdermott May 17, 2024
79d41a4
Fixed pytorch dataset tests (mostly by removing those that were faili…
mmcdermott May 17, 2024
6c20b5d
Further updated polars and fixed a small test case that the polars ch…
mmcdermott May 17, 2024
330f6ac
Makes measurement_configs a property instead of a static access.
mmcdermott May 18, 2024
2f433a6
Merge branch 'dev' into using_ragged_tensors
mmcdermott May 18, 2024
EventStream/baseline/FT_task_baseline.py (4 changes: 2 additions & 2 deletions)

@@ -32,7 +32,7 @@
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from ..data.dataset_polars import Dataset
- from ..data.pytorch_dataset import ConstructorPytorchDataset
+ from ..data.pytorch_dataset import PytorchDataset
from ..tasks.profile import add_tasks_from
from ..utils import task_wrapper

@@ -660,7 +660,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
task_dfs = add_tasks_from(ESD.config.save_dir / "task_dfs")
task_df = task_dfs[cfg.task_df_name]

- task_type, normalized_label = ConstructorPytorchDataset.normalize_task(
+ task_type, normalized_label = PytorchDataset.normalize_task(
Review comment:
Consider adding more detailed logging for better traceability during model training.

Suggested addition after the normalize_task call:
+ logger.debug(f"Task type: {task_type}, Label: {normalized_label}")

pl.col(cfg.finetuning_task_label), task_df.schema[cfg.finetuning_task_label]
)

EventStream/data/config.py (18 changes: 1 addition & 17 deletions)

@@ -1059,23 +1059,7 @@ def tensorized_cached_files(self, split: str) -> dict[str, Path]:
if not (self.tensorized_cached_dir / split).is_dir():
return {}

- all_files = {fp.stem: fp for fp in (self.tensorized_cached_dir / split).glob("*.pt")}
- files_str = ", ".join(all_files.keys())
-
- for param, need_keys in [
-     ("do_include_start_time_min", ["start_time"]),
-     ("do_include_subsequence_indices", ["start_idx", "end_idx"]),
-     ("do_include_subject_id", ["subject_id"]),
- ]:
-     param_val = getattr(self, param)
-     for need_key in need_keys:
-         if param_val:
-             if need_key not in all_files.keys():
-                 raise KeyError(f"Missing {need_key} but {param} is True! Have {files_str}")
-         elif need_key in all_files:
-             all_files.pop(need_key)
-
- return all_files
+ return {fp.stem: fp for fp in (self.tensorized_cached_dir / split).glob("*.npz")}


@dataclasses.dataclass
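Editor's note: a minimal, self-contained sketch (not the repository's code; the directory layout in the comments is hypothetical) of what the simplified method above now returns: every *.npz file in the split's tensorized cache directory, keyed by file stem, with no per-parameter key validation.

from pathlib import Path

def tensorized_cached_files(tensorized_cached_dir: Path, split: str) -> dict[str, Path]:
    # Mirror of the new one-line implementation: glob *.npz files for the split
    # and key them by file stem; a missing split directory yields an empty dict.
    split_dir = tensorized_cached_dir / split
    if not split_dir.is_dir():
        return {}
    return {fp.stem: fp for fp in split_dir.glob("*.npz")}

# Hypothetical layout: with train/0.npz and train/1.npz on disk,
# tensorized_cached_files(Path("tensorized_cached"), "train") returns
# {"0": Path("tensorized_cached/train/0.npz"), "1": Path("tensorized_cached/train/1.npz")}.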
EventStream/data/dataset_base.py (71 changes: 56 additions & 15 deletions)

@@ -18,8 +18,10 @@
import humanize
import numpy as np
import pandas as pd
+ import polars as pl
from loguru import logger
from mixins import SaveableMixin, SeedableMixin, TimeableMixin, TQDMableMixin
+ from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
from plotly.graph_objs._figure import Figure
from tqdm.auto import tqdm

Expand Down Expand Up @@ -1358,28 +1360,67 @@ def cache_deep_learning_representation(
"""

logger.info("Caching DL representations")
if subjects_per_output_file is None:
Review comment:
Consider specifying a default value for subjects_per_output_file to avoid potential issues with undefined behavior.

Suggested change:
  if subjects_per_output_file is None:
      subjects_per_output_file = 100  # Default value, adjust as necessary

logger.warning("Sharding is recommended for DL representations.")

DL_dir = self.config.save_dir / "DL_reps"
DL_dir.mkdir(exist_ok=True, parents=True)
NRT_dir = self.config.save_dir / "NRT_reps"

if subjects_per_output_file is None:
subject_chunks = [None]
shards_fp = self.config.save_dir / "DL_shards.json"
if shards_fp.exists():
shards = json.loads(shards_fp.read_text())
else:
subjects = np.random.permutation(list(self.subject_ids))
subject_chunks = np.array_split(
subjects,
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)
subject_chunks = [list(c) for c in subject_chunks]
shards = {}

if subjects_per_output_file is None:
subject_chunks = [self.subject_ids]
else:
subjects = np.random.permutation(list(self.subject_ids))
subject_chunks = np.array_split(
subjects,
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)

for chunk_idx, subjects_list in self._tqdm(list(enumerate(subject_chunks))):
cached_df = self.build_DL_cached_representation(subject_ids=subjects_list)
subject_chunks = [[int(x) for x in c] for c in subject_chunks]

for split, subjects in self.split_subjects.items():
fp = DL_dir / f"{split}_{chunk_idx}.{self.DF_SAVE_FORMAT}"
for chunk_idx, subjects_list in enumerate(subject_chunks):
for split, subjects in self.split_subjects.items():
shard_key = f"{split}/{chunk_idx}"
included_subjects = set(subjects_list).intersection({int(x) for x in subjects})
shards[shard_key] = list(included_subjects)

split_cached_df = self._filter_col_inclusion(cached_df, {"subject_id": subjects})
self._write_df(split_cached_df, fp, do_overwrite=do_overwrite)
shards_fp.write_text(json.dumps(shards))
Review comment:
Consider using json.dump directly with a file handle instead of write_text for writing JSON data. This is more idiomatic and efficient.

Suggested change:
- shards_fp.write_text(json.dumps(shards))
+ with open(shards_fp, "w") as f:
+     json.dump(shards, f)
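Editor's note: for clarity, a minimal standalone sketch of the sharding scheme introduced above, using made-up subject IDs, split assignments, and shard size; only the permutation, np.array_split chunking, and "split/chunk_idx" keying mirror the diff.

import json

import numpy as np

# Hypothetical inputs; in the real method these come from self.subject_ids and self.split_subjects.
subject_ids = list(range(1, 11))
split_subjects = {"train": {1, 2, 3, 4, 5, 6, 7}, "tuning": {8, 9, 10}}
subjects_per_output_file = 4

# Permute subjects and cut them into chunks of at most subjects_per_output_file.
subjects = np.random.permutation(subject_ids)
subject_chunks = np.array_split(
    subjects,
    np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)
subject_chunks = [[int(x) for x in c] for c in subject_chunks]

# Each shard key is "split/chunk_idx"; its value is the chunk's subjects that fall in that split.
shards = {}
for chunk_idx, subjects_list in enumerate(subject_chunks):
    for split, split_ids in split_subjects.items():
        shards[f"{split}/{chunk_idx}"] = sorted(set(subjects_list) & split_ids)

print(json.dumps(shards, indent=2))  # written to DL_shards.json in the real method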


for shard_key, subjects_list in self._tqdm(list(shards.items()), desc="Shards"):
DL_fp = DL_dir / f"{shard_key}.{self.DF_SAVE_FORMAT}"
DL_fp.parent.mkdir(exist_ok=True, parents=True)

if DL_fp.exists() and not do_overwrite:
logger.info(f"Skipping {DL_fp} as it already exists.")
cached_df = self._read_df(DL_fp)
else:
logger.info(f"Caching {shard_key} to {DL_fp}")
cached_df = self.build_DL_cached_representation(subject_ids=subjects_list)
self._write_df(cached_df, DL_fp, do_overwrite=do_overwrite)

NRT_fp = NRT_dir / f"{shard_key}.pt"
NRT_fp.parent.mkdir(exist_ok=True, parents=True)
if NRT_fp.exists() and not do_overwrite:
logger.info(f"Skipping {NRT_fp} as it already exists.")
else:
logger.info(f"Caching NRT for {shard_key} to {NRT_fp}")
# TODO(mmd): This breaks the API isolation a bit, as we assume polars here. But that's fine.
Review comment:
The comment about API isolation breakage should be addressed or clarified to ensure the design is robust and maintainable. Consider revising the architecture to maintain API isolation or provide a detailed justification for the current approach.

jnrt_dict = {
k: cached_df[k].to_list()
for k in ["time_delta", "dynamic_indices", "dynamic_measurement_indices"]
}
jnrt_dict["dynamic_values"] = (
cached_df["dynamic_values"]
.list.eval(pl.element().list.eval(pl.element().fill_null(float("nan"))))
.to_list()
)
jnrt_dict = JointNestedRaggedTensorDict(jnrt_dict)
jnrt_dict.save(NRT_fp)

@property
def vocabulary_config(self) -> VocabularyConfig:
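Editor's note: as a companion to the NRT caching block above, a minimal sketch of building and saving a JointNestedRaggedTensorDict. The data and output path are made up, and it assumes, as the diff suggests, that the constructor accepts a plain dict of ragged Python lists.

from pathlib import Path

from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

# Two hypothetical subjects: per-subject event-level time deltas (minutes) and,
# nested one level deeper, per-event measurement indices and values.
jnrt_dict = JointNestedRaggedTensorDict(
    {
        "time_delta": [[30.0, 90.0, float("nan")], [15.0, float("nan")]],
        "dynamic_indices": [[[3, 7], [2], [5, 5, 1]], [[4], [6, 2]]],
        "dynamic_values": [[[0.1, 2.3], [float("nan")], [1.0, 0.5, 0.2]], [[7.2], [0.0, 3.3]]],
    }
)

# The diff saves one such object per shard, e.g. NRT_reps/train/0.pt.
jnrt_dict.save(Path("example_shard.pt"))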
EventStream/data/dataset_polars.py (27 changes: 25 additions & 2 deletions)

@@ -10,6 +10,7 @@
import dataclasses
import math
import multiprocessing
+ from collections import defaultdict
from collections.abc import Callable, Sequence
from datetime import timedelta
from pathlib import Path
@@ -722,7 +723,24 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
case False:
filter_exprs.append(pl.col(col).is_null())
case _:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
except TypeError as e:
incl_targets_by_type = defaultdict(list)
for t in incl_targets:
incl_targets_by_type[str(type(t))].append(t)

by_type_summ = []
for tp, vals in incl_targets_by_type.items():
by_type_summ.append(
f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..."
)

by_type_summ = "\n".join(by_type_summ)

raise ValueError(
f"Failed to convert incl_targets to {df.schema[col]}:\n{by_type_summ}"
) from e
Review comment on lines +726 to +743:
The changes made to the _filter_col_inclusion method enhance its robustness by adding exception handling for type conversion issues. This is particularly useful when dealing with dynamic data types that may not always conform to expected formats. The detailed error message constructed from the incl_targets_by_type dictionary provides clear insight into the nature of the conversion failure, which can significantly aid in debugging.

However, consider adding a comment explaining the purpose of creating a summary of conversion failures by type and values. This will help future maintainers understand the rationale behind these changes more quickly.

Suggested change:
+ # Handle type conversion exceptions by summarizing conversion failures by type and values
  try:
      incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])

filter_exprs.append(pl.col(col).is_in(incl_list))

return df.filter(pl.all_horizontal(filter_exprs))
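Editor's note: a minimal standalone sketch of the failure mode the new try/except targets, with made-up inclusion targets and dtype. It assumes, as the catch clause in the diff does, that polars raises TypeError when a strictly typed Series is built from mixed-type values.

from collections import defaultdict

import polars as pl

incl_targets = [1, 2, "three"]  # hypothetical mixed-type inclusion targets
col_dtype = pl.Int64            # hypothetical target column dtype

try:
    incl_list = pl.Series(list(incl_targets), dtype=col_dtype)
except TypeError:
    # Summarize the offending values by Python type, as the new error message does.
    incl_targets_by_type = defaultdict(list)
    for t in incl_targets:
        incl_targets_by_type[str(type(t))].append(t)
    by_type_summ = "\n".join(
        f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..."
        for tp, vals in incl_targets_by_type.items()
    )
    print(f"Failed to convert incl_targets to {col_dtype}:\n{by_type_summ}")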
@@ -1389,12 +1407,17 @@ def build_DL_cached_representation(
pl.col("value").alias("dynamic_values"),
)
.sort("subject_id", "timestamp")
.group_by("subject_id")
.group_by("subject_id", maintain_order=True)
.agg(
pl.col("timestamp").first().alias("start_time"),
((pl.col("timestamp") - pl.col("timestamp").min()).dt.total_nanoseconds() / (1e9 * 60)).alias(
"time"
),
(pl.col("timestamp").diff().dt.total_seconds() / 60.0)
.shift(-1)
.cast(pl.Float32)
.fill_null(float("nan"))
.alias("time_delta"),
pl.col("dynamic_measurement_indices"),
pl.col("dynamic_indices"),
pl.col("dynamic_values"),
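Editor's note: to illustrate the new time_delta aggregation above, a minimal sketch with a toy DataFrame (the subjects and timestamps are made up): for each subject, time_delta holds the minutes until the next event, with NaN for the last event.

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "subject_id": [1, 1, 1, 2],
        "timestamp": [
            datetime(2023, 1, 1, 0, 0),
            datetime(2023, 1, 1, 0, 30),
            datetime(2023, 1, 1, 2, 0),
            datetime(2023, 1, 2, 12, 0),
        ],
    }
)

out = (
    df.sort("subject_id", "timestamp")
    .group_by("subject_id", maintain_order=True)
    .agg(
        pl.col("timestamp").first().alias("start_time"),
        (pl.col("timestamp").diff().dt.total_seconds() / 60.0)
        .shift(-1)
        .cast(pl.Float32)
        .fill_null(float("nan"))
        .alias("time_delta"),
    )
)
print(out)  # subject 1 -> time_delta [30.0, 90.0, NaN]; subject 2 -> [NaN]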