From 291a2e90464ed1977009184d4ae9d8a945fed29a Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 19 Sep 2024 21:38:07 +0800 Subject: [PATCH] improve dataset saving unify val to validation make `save_dataset` support saving multiple files allow passing filename in `save_dataset` Signed-off-by: Zhiyuan Chen --- multimolecule/data/utils.py | 12 +++++------- multimolecule/datasets/bprna_new/bprna_new.py | 2 +- multimolecule/datasets/bprna_spot/bprna_spot.py | 14 ++------------ multimolecule/datasets/conversion_utils.py | 16 ++++++++++++++-- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py index de842bc4..85bc4423 100644 --- a/multimolecule/data/utils.py +++ b/multimolecule/data/utils.py @@ -42,10 +42,10 @@ def infer_task( ) -> Task: if max_seq_length is not None and seq_length_offset is not None: max_seq_length -= seq_length_offset - if isinstance(sequence, ChunkedArray) and sequence.num_chunks == 1: - sequence = sequence.chunks[0] - if isinstance(column, ChunkedArray) and column.num_chunks == 1: - column = column.chunks[0] + if isinstance(sequence, ChunkedArray): + sequence = sequence.combine_chunks() + if isinstance(column, ChunkedArray): + column = column.combine_chunks() flattened, levels = flatten_column(column, truncation, max_seq_length) dtype = flattened.type unique = flattened.unique() @@ -145,14 +145,12 @@ def flatten_column( def get_num_tokens(sequence: Array | ListArray, seq_length_offset: int | None = None) -> Tuple[int, int]: - if isinstance(sequence, StringArray): + if isinstance(sequence, StringArray) or isinstance(sequence[0], pa.lib.StringScalar): return sum(len(i.as_py()) for i in sequence), sum(len(i.as_py()) ** 2 for i in sequence) # remove and tokens in length calculation if seq_length_offset is None: warn("seq_length_offset not specified, automatically detecting and tokens") seq_length_offset = 0 - if isinstance(sequence[0], pa.lib.StringScalar): - raise ValueError("seq_length_offset must be specified for StringScalar sequences") if len({i[0] for i in sequence}) == 1: seq_length_offset += 1 if len({i[-1] for i in sequence}) == 1: diff --git a/multimolecule/datasets/bprna_new/bprna_new.py b/multimolecule/datasets/bprna_new/bprna_new.py index 218737ec..d622f9a8 100644 --- a/multimolecule/datasets/bprna_new/bprna_new.py +++ b/multimolecule/datasets/bprna_new/bprna_new.py @@ -48,7 +48,7 @@ def convert_bpseq(bpseq): def convert_dataset(convert_config): data = [convert_bpseq(file) for file in tqdm(get_files(convert_config.dataset_path))] - save_dataset(convert_config, data) + save_dataset(convert_config, data, filename="test.parquet") class ConvertConfig(ConvertConfig_): diff --git a/multimolecule/datasets/bprna_spot/bprna_spot.py b/multimolecule/datasets/bprna_spot/bprna_spot.py index 3b5cb105..76e90b92 100644 --- a/multimolecule/datasets/bprna_spot/bprna_spot.py +++ b/multimolecule/datasets/bprna_spot/bprna_spot.py @@ -17,27 +17,17 @@ from __future__ import annotations import os -from collections.abc import Mapping import torch from tqdm import tqdm from multimolecule.datasets.bprna.bprna import convert_sta from multimolecule.datasets.conversion_utils import ConvertConfig as ConvertConfig_ -from multimolecule.datasets.conversion_utils import copy_readme, get_files, push_to_hub, write_data +from multimolecule.datasets.conversion_utils import get_files, save_dataset torch.manual_seed(1016) -def save_dataset(convert_config: ConvertConfig, data: Mapping, compression: str = "brotli", level: int = 4): - root, output_path = convert_config.root, convert_config.output_path - os.makedirs(output_path, exist_ok=True) - for name, d in data.items(): - write_data(d, output_path, name + ".parquet", compression, level) - copy_readme(root, output_path) - push_to_hub(convert_config, output_path) - - def _convert_dataset(dataset): files = get_files(dataset) return [convert_sta(file) for file in tqdm(files, total=len(files))] @@ -46,7 +36,7 @@ def _convert_dataset(dataset): def convert_dataset(convert_config): data = { "train": _convert_dataset(os.path.join(convert_config.dataset_path, "TR0")), - "val": _convert_dataset(os.path.join(convert_config.dataset_path, "VL0")), + "validation": _convert_dataset(os.path.join(convert_config.dataset_path, "VL0")), "test": _convert_dataset(os.path.join(convert_config.dataset_path, "TS0")), } save_dataset(convert_config, data) diff --git a/multimolecule/datasets/conversion_utils.py b/multimolecule/datasets/conversion_utils.py index 14cc9c80..7bf90012 100644 --- a/multimolecule/datasets/conversion_utils.py +++ b/multimolecule/datasets/conversion_utils.py @@ -18,6 +18,8 @@ import os import shutil +from collections.abc import Mapping +from warnings import warn import pyarrow as pa from chanfig import Config @@ -74,11 +76,21 @@ def push_to_hub(convert_config: ConvertConfig, output_path: str, repo_type: str def save_dataset( - convert_config: ConvertConfig, data: Table | list | dict | DataFrame, compression: str = "brotli", level: int = 4 + convert_config: ConvertConfig, + data: Table | list | dict | DataFrame, + filename: str = "data.parquet", + compression: str = "brotli", + level: int = 4, ): root, output_path = convert_config.root, convert_config.output_path os.makedirs(output_path, exist_ok=True) - write_data(data, output_path, compression=compression, level=level) + if isinstance(data, Mapping): + if filename != "data.parquet": + warn("Filename is ignored when saving multiple datasets.") + for name, d in data.items(): + write_data(d, output_path, filename=name + ".parquet", compression=compression, level=level) + else: + write_data(data, output_path, filename=filename, compression=compression, level=level) copy_readme(root, output_path) push_to_hub(convert_config, output_path)