-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes the slowdowns and bugs caused by the prior improved compute practices, but requires a nested tensor package. #90
Changes from 27 commits
007fd49
ca65a71
5947a91
774a9bc
d0ee233
5d3b403
621a939
9f5013b
c05428b
e346808
b22e6ce
a023b62
29c3b9f
8c73fd3
78a1aac
4245128
35bacf2
6095101
fa33387
475a7f0
fd20002
1dfb6ef
2873344
476661a
cd9a661
6a370bf
79d41a4
6c20b5d
330f6ac
2f433a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -18,8 +18,10 @@ | |||||||
import humanize | ||||||||
import numpy as np | ||||||||
import pandas as pd | ||||||||
import polars as pl | ||||||||
from loguru import logger | ||||||||
from mixins import SaveableMixin, SeedableMixin, TimeableMixin, TQDMableMixin | ||||||||
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict | ||||||||
from plotly.graph_objs._figure import Figure | ||||||||
from tqdm.auto import tqdm | ||||||||
|
||||||||
|
@@ -1358,28 +1360,67 @@ def cache_deep_learning_representation( | |||||||
""" | ||||||||
|
||||||||
logger.info("Caching DL representations") | ||||||||
if subjects_per_output_file is None: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider specifying a default value for - if subjects_per_output_file is None:
+ if subjects_per_output_file is None:
+ subjects_per_output_file = 100 # Default value, adjust as necessary Committable suggestion
Suggested change
|
||||||||
logger.warning("Sharding is recommended for DL representations.") | ||||||||
|
||||||||
DL_dir = self.config.save_dir / "DL_reps" | ||||||||
DL_dir.mkdir(exist_ok=True, parents=True) | ||||||||
NRT_dir = self.config.save_dir / "NRT_reps" | ||||||||
|
||||||||
if subjects_per_output_file is None: | ||||||||
subject_chunks = [None] | ||||||||
shards_fp = self.config.save_dir / "DL_shards.json" | ||||||||
if shards_fp.exists(): | ||||||||
shards = json.loads(shards_fp.read_text()) | ||||||||
else: | ||||||||
subjects = np.random.permutation(list(self.subject_ids)) | ||||||||
subject_chunks = np.array_split( | ||||||||
subjects, | ||||||||
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file), | ||||||||
) | ||||||||
subject_chunks = [list(c) for c in subject_chunks] | ||||||||
shards = {} | ||||||||
|
||||||||
if subjects_per_output_file is None: | ||||||||
subject_chunks = [self.subject_ids] | ||||||||
else: | ||||||||
subjects = np.random.permutation(list(self.subject_ids)) | ||||||||
subject_chunks = np.array_split( | ||||||||
subjects, | ||||||||
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file), | ||||||||
) | ||||||||
|
||||||||
for chunk_idx, subjects_list in self._tqdm(list(enumerate(subject_chunks))): | ||||||||
cached_df = self.build_DL_cached_representation(subject_ids=subjects_list) | ||||||||
subject_chunks = [[int(x) for x in c] for c in subject_chunks] | ||||||||
|
||||||||
for split, subjects in self.split_subjects.items(): | ||||||||
fp = DL_dir / f"{split}_{chunk_idx}.{self.DF_SAVE_FORMAT}" | ||||||||
for chunk_idx, subjects_list in enumerate(subject_chunks): | ||||||||
for split, subjects in self.split_subjects.items(): | ||||||||
shard_key = f"{split}/{chunk_idx}" | ||||||||
included_subjects = set(subjects_list).intersection({int(x) for x in subjects}) | ||||||||
shards[shard_key] = list(included_subjects) | ||||||||
|
||||||||
split_cached_df = self._filter_col_inclusion(cached_df, {"subject_id": subjects}) | ||||||||
self._write_df(split_cached_df, fp, do_overwrite=do_overwrite) | ||||||||
shards_fp.write_text(json.dumps(shards)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using - shards_fp.write_text(json.dumps(shards))
+ with open(shards_fp, 'w') as f:
+ json.dump(shards, f) Committable suggestion
Suggested change
|
||||||||
|
||||||||
for shard_key, subjects_list in self._tqdm(list(shards.items()), desc="Shards"): | ||||||||
DL_fp = DL_dir / f"{shard_key}.{self.DF_SAVE_FORMAT}" | ||||||||
DL_fp.parent.mkdir(exist_ok=True, parents=True) | ||||||||
|
||||||||
if DL_fp.exists() and not do_overwrite: | ||||||||
logger.info(f"Skipping {DL_fp} as it already exists.") | ||||||||
cached_df = self._read_df(DL_fp) | ||||||||
else: | ||||||||
logger.info(f"Caching {shard_key} to {DL_fp}") | ||||||||
cached_df = self.build_DL_cached_representation(subject_ids=subjects_list) | ||||||||
self._write_df(cached_df, DL_fp, do_overwrite=do_overwrite) | ||||||||
|
||||||||
NRT_fp = NRT_dir / f"{shard_key}.pt" | ||||||||
NRT_fp.parent.mkdir(exist_ok=True, parents=True) | ||||||||
if NRT_fp.exists() and not do_overwrite: | ||||||||
logger.info(f"Skipping {NRT_fp} as it already exists.") | ||||||||
else: | ||||||||
logger.info(f"Caching NRT for {shard_key} to {NRT_fp}") | ||||||||
# TODO(mmd): This breaks the API isolation a bit, as we assume polars here. But that's fine. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment about API isolation breakage should be addressed or clarified to ensure the design is robust and maintainable. Consider revising the architecture to maintain API isolation or provide a detailed justification for the current approach. |
||||||||
jnrt_dict = { | ||||||||
k: cached_df[k].to_list() | ||||||||
for k in ["time_delta", "dynamic_indices", "dynamic_measurement_indices"] | ||||||||
} | ||||||||
jnrt_dict["dynamic_values"] = ( | ||||||||
cached_df["dynamic_values"] | ||||||||
.list.eval(pl.element().list.eval(pl.element().fill_null(float("nan")))) | ||||||||
.to_list() | ||||||||
) | ||||||||
jnrt_dict = JointNestedRaggedTensorDict(jnrt_dict) | ||||||||
jnrt_dict.save(NRT_fp) | ||||||||
|
||||||||
@property | ||||||||
def vocabulary_config(self) -> VocabularyConfig: | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,6 +10,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import dataclasses | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import math | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import multiprocessing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from collections import defaultdict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from collections.abc import Callable, Sequence | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from datetime import timedelta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -722,7 +723,24 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
case False: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filter_exprs.append(pl.col(col).is_null()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
case _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
except TypeError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
incl_targets_by_type = defaultdict(list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for t in incl_targets: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
incl_targets_by_type[str(type(t))].append(t) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
by_type_summ = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for tp, vals in incl_targets_by_type.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
by_type_summ.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
by_type_summ = "\n".join(by_type_summ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Failed to convert incl_targets to {df.schema[col]}:\n{by_type_summ}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) from e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+726
to
+743
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The changes made to the However, consider adding a comment explaining the purpose of creating a summary of conversion failures by type and values. This will help future maintainers understand the rationale behind these changes more quickly. + # Handle type conversion exceptions by summarizing conversion failures by type and values
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filter_exprs.append(pl.col(col).is_in(incl_list)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return df.filter(pl.all_horizontal(filter_exprs)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1389,12 +1407,17 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pl.col("value").alias("dynamic_values"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.sort("subject_id", "timestamp") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.group_by("subject_id") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.group_by("subject_id", maintain_order=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.agg( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pl.col("timestamp").first().alias("start_time"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
((pl.col("timestamp") - pl.col("timestamp").min()).dt.total_nanoseconds() / (1e9 * 60)).alias( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"time" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(pl.col("timestamp").diff().dt.total_seconds() / 60.0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.shift(-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.cast(pl.Float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.fill_null(float("nan")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.alias("time_delta"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pl.col("dynamic_measurement_indices"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pl.col("dynamic_indices"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pl.col("dynamic_values"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding more detailed logging for better traceability during model training.
+ logger.debug(f"Task type: {task_type}, Label: {normalized_label}")
Committable suggestion