Improve all aspects of compute performance (save disk space cost) for pytorch datasets by pre-caching processed items. #76

Closed · wants to merge 16 commits
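Before the file-by-file diff, a minimal, generic sketch of the pre-caching idea the PR title describes, assuming nothing about this repository's actual API (the class, method names, and item format below are hypothetical): each item is processed once, and the processed tensors are re-served on later accesses, trading memory/disk for compute.

```python
# Hypothetical sketch of item pre-caching for a PyTorch dataset; names are illustrative only.
import torch
from torch.utils.data import Dataset


class CachingDataset(Dataset):
    """Wraps an expensive-to-process dataset and re-serves processed items from a cache."""

    def __init__(self, raw_items: list, cache_for_epochs: int = 1):
        self.raw_items = raw_items
        self.cache_for_epochs = cache_for_epochs
        self._cache: dict[int, dict[str, torch.Tensor]] = {}

    def _process(self, raw) -> dict[str, torch.Tensor]:
        # Stand-in for the real (expensive) tensorization step.
        return {"values": torch.as_tensor(raw, dtype=torch.float32)}

    def __len__(self) -> int:
        return len(self.raw_items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if self.cache_for_epochs > 1 and idx in self._cache:
            return self._cache[idx]
        item = self._process(self.raw_items[idx])
        if self.cache_for_epochs > 1:
            self._cache[idx] = item
        return item
```

A full implementation would also track how many epochs a cached item has been served and refresh it afterwards; the `EventStream/data/config.py` changes below additionally key an on-disk cache directory by a hash of the data parameters rather than holding everything in memory.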
4 changes: 2 additions & 2 deletions EventStream/baseline/FT_task_baseline.py
@@ -31,7 +31,7 @@
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from ..data.dataset_polars import Dataset
from ..data.pytorch_dataset import PytorchDataset
from ..data.pytorch_dataset import ConstructorPytorchDataset
from ..tasks.profile import add_tasks_from
from ..utils import task_wrapper

@@ -658,7 +658,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
task_dfs = add_tasks_from(ESD.config.save_dir / "task_dfs")
task_df = task_dfs[cfg.task_df_name]

task_type, normalized_label = PytorchDataset.normalize_task(
task_type, normalized_label = ConstructorPytorchDataset.normalize_task(
pl.col(cfg.finetuning_task_label), task_df.schema[cfg.finetuning_task_label]
)

124 changes: 124 additions & 0 deletions EventStream/data/config.py
@@ -4,6 +4,8 @@

import dataclasses
import enum
import hashlib
import json
import random
from collections import OrderedDict, defaultdict
from collections.abc import Hashable, Sequence
@@ -880,7 +882,19 @@
do_include_subject_id: bool = False
do_include_start_time_min: bool = False

# Trades off speed against disk/memory cost; values > 1 require random subsequence sampling
cache_for_epochs: int = 1

def __post_init__(self):
if self.cache_for_epochs is None:
self.cache_for_epochs = 1

if self.subsequence_sampling_strategy != "random" and self.cache_for_epochs > 1:
raise ValueError(
f"It does not make sense to cache for {self.cache_for_epochs} epochs with non-random "
"subsequence sampling."
)

if self.seq_padding_side not in SeqPaddingSide.values():
raise ValueError(f"seq_padding_side invalid; must be in {', '.join(SeqPaddingSide.values())}")
if type(self.min_seq_len) is not int or self.min_seq_len < 0:
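The `cache_for_epochs` guard added above only permits multi-epoch caching when subsequences are sampled randomly. Below is a standalone sketch of that validation using a stand-in dataclass rather than the repository's config class; the strategy string "to_end" is just an arbitrary non-random placeholder.

```python
# Stand-in dataclass reproducing the cache_for_epochs guard; not the repository's config class.
import dataclasses


@dataclasses.dataclass
class SketchConfig:
    subsequence_sampling_strategy: str = "random"
    cache_for_epochs: int = 1

    def __post_init__(self):
        if self.cache_for_epochs is None:
            self.cache_for_epochs = 1
        if self.subsequence_sampling_strategy != "random" and self.cache_for_epochs > 1:
            raise ValueError(
                f"It does not make sense to cache for {self.cache_for_epochs} epochs with non-random "
                "subsequence sampling."
            )


SketchConfig(subsequence_sampling_strategy="random", cache_for_epochs=5)  # accepted

try:
    SketchConfig(subsequence_sampling_strategy="to_end", cache_for_epochs=5)
except ValueError as err:
    # With deterministic sampling, one cached copy already covers every epoch.
    print(err)
```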
@@ -920,6 +934,116 @@
as_dict["save_dir"] = Path(as_dict["save_dir"])
return cls(**as_dict)

@property
def vocabulary_config_fp(self) -> Path:
return self.save_dir / "vocabulary_config.json"

@property
def vocabulary_config(self) -> VocabularyConfig:
return VocabularyConfig.from_json_file(self.vocabulary_config_fp)

@property
def measurement_config_fp(self) -> Path:
return self.save_dir / "inferred_measurement_configs.json"

@property
def measurement_configs(self) -> dict[str, MeasurementConfig]:
with open(self.measurement_config_fp) as f:
measurement_configs = {k: MeasurementConfig.from_dict(v) for k, v in json.load(f).items()}
return {k: v for k, v in measurement_configs.items() if not v.is_dropped}

@property
def DL_reps_dir(self) -> Path:
return self.save_dir / "DL_reps"

@property
def cached_task_dir(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.save_dir / "DL_reps" / "for_task" / self.task_df_name

@property
def raw_task_df_fp(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.save_dir / "task_dfs" / f"{self.task_df_name}.parquet"

@property
def task_info_fp(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.cached_task_dir / "task_info.json"

@property
def _data_parameters_and_hash(self) -> tuple[dict[str, Any], str]:
params = sorted(
(
"save_dir",
"max_seq_len",
"min_seq_len",
"seq_padding_side",
"subsequence_sampling_strategy",
"train_subset_size",
"train_subset_seed",
"task_df_name",
)
)

params_list = []
for p in params:
v = str(getattr(self, p))
params_list.append((p, v))

params = tuple(params_list)
h = hashlib.blake2b(digest_size=8)
h.update(str(params).encode())

return {k: v for k, v in params}, h.hexdigest()

@property
def tensorized_cached_dir(self) -> Path:
if self.task_df_name is None:
base_dir = self.DL_reps_dir / "tensorized_cached"
else:
base_dir = self.cached_task_dir

return base_dir / self._data_parameters_and_hash[1]

@property
def _cached_data_parameters_fp(self) -> Path:
return self.tensorized_cached_dir / "data_parameters.json"

def _cache_data_parameters(self):
self._cached_data_parameters_fp.parent.mkdir(exist_ok=True, parents=True)

with open(self._cached_data_parameters_fp, mode="w") as f:
json.dump(self._data_parameters_and_hash[0], f)

def tensorized_cached_files(self, split: str) -> dict[str, Path]:
if not (self.tensorized_cached_dir / split).is_dir():
return {}

all_files = {fp.stem: fp for fp in (self.tensorized_cached_dir / split).glob("*.pt")}
files_str = ", ".join(all_files.keys())

for param, need_keys in [
("do_include_start_time_min", ["start_time"]),
("do_include_subsequence_indices", ["start_idx", "end_idx"]),
("do_include_subject_id", ["subject_id"]),
]:
param_val = getattr(self, param)
for need_key in need_keys:
if param_val:
if need_key not in all_files.keys():
raise KeyError(f"Missing {need_key} but {param} is True! Have {files_str}")
elif need_key in all_files:
all_files.pop(need_key)

return all_files


@dataclasses.dataclass
class MeasurementConfig(JSONableMixin):
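To make the caching-directory scheme concrete, here is a self-contained sketch of how `_data_parameters_and_hash` derives a stable directory name: the caching-relevant parameter names are sorted, paired with their stringified values, and hashed with an 8-byte BLAKE2b digest, so changing any of those parameters points the loader at a different `tensorized_cached` subdirectory. The parameter values below are invented for illustration.

```python
# Sketch of the parameter-hashing scheme behind tensorized_cached_dir.
# Parameter values are invented; only the hashing logic mirrors the diff above.
import hashlib
from pathlib import Path

params = {
    "save_dir": "/data/esd",
    "max_seq_len": 256,
    "min_seq_len": 4,
    "seq_padding_side": "right",
    "subsequence_sampling_strategy": "random",
    "train_subset_size": "FULL",
    "train_subset_seed": 1,
    "task_df_name": None,
}

# Sorted (name, str(value)) pairs, hashed with blake2b(digest_size=8), as in the new property.
param_tuple = tuple((k, str(params[k])) for k in sorted(params))
h = hashlib.blake2b(digest_size=8)
h.update(str(param_tuple).encode())

cache_dir = Path(params["save_dir"]) / "DL_reps" / "tensorized_cached" / h.hexdigest()
print(cache_dir)  # e.g. /data/esd/DL_reps/tensorized_cached/<16 hex characters>
```

When a task dataframe is configured, `tensorized_cached_dir` nests the same hash under `DL_reps/for_task/<task_df_name>` instead, and `tensorized_cached_files` then checks the cached `*.pt` files against the `do_include_*` flags, raising if a required file is missing and dropping files whose flag is off.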