From 672c905e0ffbe2118093ef6fa1e623eac72ad564 Mon Sep 17 00:00:00 2001
From: Elliott
Date: Wed, 18 Oct 2023 10:57:26 -0400
Subject: [PATCH] feat: s2s auto metadata (#777)

---
 dataquality/__init__.py                    |  2 +-
 dataquality/dq_auto/text_classification.py |  4 ++--
 dataquality/integrations/seq2seq/auto.py   |  5 +++--
 dataquality/integrations/setfit.py         | 11 +++++++----
 dataquality/loggers/data_logger/seq2seq.py |  1 +
 dataquality/utils/auto.py                  | 13 ++++++++++++-
 dataquality/utils/setfit.py                | 12 +++---------
 pyproject.toml                             |  2 +-
 8 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/dataquality/__init__.py b/dataquality/__init__.py
index 4b9480777..46fa2b853 100644
--- a/dataquality/__init__.py
+++ b/dataquality/__init__.py
@@ -31,7 +31,7 @@
 """
 
 
-__version__ = "1.1.1"
+__version__ = "1.1.2"
 
 import sys
 from typing import Any, List, Optional
diff --git a/dataquality/dq_auto/text_classification.py b/dataquality/dq_auto/text_classification.py
index 68534a7f3..883b55980 100644
--- a/dataquality/dq_auto/text_classification.py
+++ b/dataquality/dq_auto/text_classification.py
@@ -14,10 +14,10 @@
 from dataquality.utils.auto import (
     add_class_label_to_dataset,
     add_val_data_if_missing,
+    get_meta_cols,
     run_name_from_hf_dataset,
 )
 from dataquality.utils.auto_trainer import do_train
-from dataquality.utils.setfit import _get_meta_cols
 
 a = Analytics(ApiClient, dq.config)
 a.log_import("auto_tc")
@@ -105,7 +105,7 @@ def _get_labels(dd: DatasetDict, labels: Optional[List[str]] = None) -> List[str
 def _log_dataset_dict(dd: DatasetDict) -> None:
     for key in dd:
         ds = dd[key]
-        meta = _get_meta_cols(ds.features)
+        meta = get_meta_cols(ds.features)
         if key in Split.get_valid_keys():
             dq.log_dataset(ds, meta=meta, split=key)
         else:
diff --git a/dataquality/integrations/seq2seq/auto.py b/dataquality/integrations/seq2seq/auto.py
index 17a743d7a..bfc6e0b29 100644
--- a/dataquality/integrations/seq2seq/auto.py
+++ b/dataquality/integrations/seq2seq/auto.py
@@ -17,6 +17,7 @@
 from dataquality.schemas.task_type import TaskType
 from dataquality.utils.auto import (
     add_val_data_if_missing,
+    get_meta_cols,
     run_name_from_hf_dataset,
 )
 from dataquality.utils.torch import cleanup_cuda
@@ -138,12 +139,12 @@ def _log_dataset_dict(dd: DatasetDict, input_col: str, target_col: str) -> None:
     for key in dd.keys():
         ds: Dataset = dd[key]
         if key in Split.get_valid_keys():
+            meta = get_meta_cols(ds.features, {input_col, target_col})
             if input_col != "text" and "text" in ds.column_names:
                 ds = ds.rename_columns({"text": "_metadata_text"})
             if target_col != "label" and "label" in ds.column_names:
                 ds = ds.rename_columns({"label": "_metadata_label"})
-
-            dq.log_dataset(ds, text=input_col, label=target_col, split=key)
+            dq.log_dataset(ds, text=input_col, label=target_col, split=key, meta=meta)
 
 
 def auto(
diff --git a/dataquality/integrations/setfit.py b/dataquality/integrations/setfit.py
index 56722375a..ecbfef348 100644
--- a/dataquality/integrations/setfit.py
+++ b/dataquality/integrations/setfit.py
@@ -17,11 +17,14 @@
 )
 from dataquality.schemas.split import Split
 from dataquality.schemas.task_type import TaskType
-from dataquality.utils.auto import _apply_column_mapping, run_name_from_hf_dataset
+from dataquality.utils.auto import (
+    _apply_column_mapping,
+    get_meta_cols,
+    run_name_from_hf_dataset,
+)
 from dataquality.utils.patcher import PatchManager
 from dataquality.utils.setfit import (
     SetFitModelHook,
-    _get_meta_cols,
     _prepare_config,
     _setup_patches,
     get_trainer,
 )
@@ -346,7 +349,7 @@ def do_model_eval(
     for split in [Split.train, Split.test, Split.val]:
         if split in encoded_data:
             ds = encoded_data[split]
-            meta_columns = _get_meta_cols(ds.column_names)
+            meta_columns = get_meta_cols(ds.column_names)
             dq_evaluate(
                 encoded_data[split],
                 split=split,
@@ -358,7 +361,7 @@
     inf_names = [k for k in encoded_data if k not in Split.get_valid_keys()]
     for inf_name in inf_names:
         ds = encoded_data[inf_name]
-        meta_columns = _get_meta_cols(ds.column_names)
+        meta_columns = get_meta_cols(ds.column_names)
         dq_evaluate(
             ds,
             split=Split.inference,  # type: ignore
diff --git a/dataquality/loggers/data_logger/seq2seq.py b/dataquality/loggers/data_logger/seq2seq.py
index bd9811c78..682db9179 100644
--- a/dataquality/loggers/data_logger/seq2seq.py
+++ b/dataquality/loggers/data_logger/seq2seq.py
@@ -132,6 +132,7 @@ def _get_input_df(self) -> DataFrame:
                 C.split_.value: [self.split] * len(self.ids),
                 C.token_label_positions.value: pa.array(self.token_label_positions),
                 C.token_label_offsets.value: pa.array(self.token_label_offsets),
+                **self.meta,
             }
         )
 
diff --git a/dataquality/utils/auto.py b/dataquality/utils/auto.py
index f81ec5464..3a6d41fac 100644
--- a/dataquality/utils/auto.py
+++ b/dataquality/utils/auto.py
@@ -3,7 +3,7 @@
 import re
 import warnings
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union
 
 import pandas as pd
 from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
@@ -14,6 +14,17 @@
 from dataquality.utils.name import BAD_CHARS_REGEX
 
 
+def get_meta_cols(
+    cols: Iterable, reserved_cols: Optional[Set[str]] = None
+) -> List[str]:
+    """Returns the meta columns of a dataset."""
+    reserved_cols = reserved_cols or set()
+    default_cols = {"text", "label", "id"}
+    default_cols = set(reserved_cols).union(default_cols)
+    meta_columns = [col for col in cols if col not in default_cols]
+    return list(meta_columns)
+
+
 def load_data_from_str(data: str) -> Union[pd.DataFrame, Dataset]:
     """Loads string data from either hf or disk.
 
diff --git a/dataquality/utils/setfit.py b/dataquality/utils/setfit.py
index 76fd72d82..c637aea40 100644
--- a/dataquality/utils/setfit.py
+++ b/dataquality/utils/setfit.py
@@ -1,6 +1,6 @@
 import uuid
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -10,6 +10,7 @@
 import dataquality as dq
 from dataquality.schemas.split import Split
 from dataquality.schemas.task_type import TaskType
+from dataquality.utils.auto import get_meta_cols
 from dataquality.utils.patcher import Patch, PatchManager
 
 BATCH_LOG_SIZE = 10_000
@@ -35,13 +36,6 @@ def clear(self) -> None:
         self.meta.clear()
 
 
-def _get_meta_cols(cols: Iterable) -> List[str]:
-    """Returns the meta columns of a dataset."""
-    default_cols = ["text", "label", "id"]
-    meta_columns = [col for col in cols if col not in default_cols]
-    return meta_columns
-
-
 def log_preds_setfit(
     model: "SetFitModel",
     dataset: Dataset,
@@ -79,7 +73,7 @@
     skip_logging = logger_config.helper_data[f"setfit_skip_input_log_{split}"]
     # Iterate over the dataset in batches and log the data samples
    # and model outputs
-    meta = _get_meta_cols(dataset.column_names)
+    meta = get_meta_cols(dataset.column_names)
     for i in range(0, len(dataset), batch_size):
         batch = dataset[i : i + batch_size]
         assert text_col in batch, f"column '{text_col}' must be in batch"
diff --git a/pyproject.toml b/pyproject.toml
index 7deeda8e7..a0aaf3b73 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,7 +86,7 @@ test = [
     "setfit",
     "accelerate>=0.19.0",
     "typing-inspect==0.8.0",
-    "typing-extensions==4.0.0",
+    "typing-extensions==4.0.1",
     "lightning",
 ]
 dev = [
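
For reference, the relocated get_meta_cols helper keeps every column that is not reserved. A minimal sketch of its behavior once this patch is applied, using hypothetical column names:

# Illustration only: hypothetical column names, not part of the patch.
from dataquality.utils.auto import get_meta_cols

cols = ["text", "label", "id", "source", "author"]

# "text", "label", and "id" are always treated as reserved.
print(get_meta_cols(cols))              # ['source', 'author']

# Callers can reserve extra columns, as the seq2seq auto() logging now does
# with its input and target columns.
print(get_meta_cols(cols, {"source"}))  # ['author']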