feat: s2s auto metadata (#777)
Elliott authored Oct 18, 2023
1 parent b335d1d commit 672c905
Showing 8 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
 """


-__version__ = "1.1.1"
+__version__ = "1.1.2"

 import sys
 from typing import Any, List, Optional
4 changes: 2 additions & 2 deletions dataquality/dq_auto/text_classification.py
@@ -14,10 +14,10 @@
 from dataquality.utils.auto import (
     add_class_label_to_dataset,
     add_val_data_if_missing,
+    get_meta_cols,
     run_name_from_hf_dataset,
 )
 from dataquality.utils.auto_trainer import do_train
-from dataquality.utils.setfit import _get_meta_cols

 a = Analytics(ApiClient, dq.config)
 a.log_import("auto_tc")
@@ -105,7 +105,7 @@ def _get_labels(dd: DatasetDict, labels: Optional[List[str]] = None) -> List[str]:
 def _log_dataset_dict(dd: DatasetDict) -> None:
     for key in dd:
         ds = dd[key]
-        meta = _get_meta_cols(ds.features)
+        meta = get_meta_cols(ds.features)
         if key in Split.get_valid_keys():
             dq.log_dataset(ds, meta=meta, split=key)
         else:
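
With the helper now shared from dataquality.utils.auto, any dataset column beyond the reserved text/label/id set is picked up and logged as metadata alongside the split. A minimal sketch of the new behavior, with an invented "source" column (hypothetical data, assuming the post-commit dataquality package):

from datasets import Dataset, DatasetDict
from dataquality.utils.auto import get_meta_cols

# "source" is neither text, label, nor id, so it is flagged as metadata
dd = DatasetDict({
    "train": Dataset.from_dict({"text": ["hi"], "label": [0], "source": ["forum"]})
})
for key in dd:
    meta = get_meta_cols(dd[key].features)
    print(key, meta)  # train ['source']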
5 changes: 3 additions & 2 deletions dataquality/integrations/seq2seq/auto.py
@@ -17,6 +17,7 @@
 from dataquality.schemas.task_type import TaskType
 from dataquality.utils.auto import (
     add_val_data_if_missing,
+    get_meta_cols,
     run_name_from_hf_dataset,
 )
 from dataquality.utils.torch import cleanup_cuda
@@ -138,12 +139,12 @@ def _log_dataset_dict(dd: DatasetDict, input_col: str, target_col: str) -> None:
     for key in dd.keys():
         ds: Dataset = dd[key]
         if key in Split.get_valid_keys():
+            meta = get_meta_cols(ds.features, {input_col, target_col})
             if input_col != "text" and "text" in ds.column_names:
                 ds = ds.rename_columns({"text": "_metadata_text"})
             if target_col != "label" and "label" in ds.column_names:
                 ds = ds.rename_columns({"label": "_metadata_label"})

-            dq.log_dataset(ds, text=input_col, label=target_col, split=key)
+            dq.log_dataset(ds, text=input_col, label=target_col, split=key, meta=meta)


 def auto(
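
For seq2seq the input and target column names are user-configurable, so _log_dataset_dict passes them to get_meta_cols as reserved columns; the _metadata_text/_metadata_label renames guard against collisions with dq's own reserved names. A sketch with invented column names:

from datasets import Dataset
from dataquality.utils.auto import get_meta_cols

# "question"/"answer" are the configured input/target columns,
# so only "topic" survives as a metadata column for dq.log_dataset.
ds = Dataset.from_dict({
    "question": ["What is 2+2?"],
    "answer": ["4"],
    "topic": ["math"],
})
meta = get_meta_cols(ds.features, {"question", "answer"})
print(meta)  # ['topic']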
11 changes: 7 additions & 4 deletions dataquality/integrations/setfit.py
@@ -17,11 +17,14 @@
 )
 from dataquality.schemas.split import Split
 from dataquality.schemas.task_type import TaskType
-from dataquality.utils.auto import _apply_column_mapping, run_name_from_hf_dataset
+from dataquality.utils.auto import (
+    _apply_column_mapping,
+    get_meta_cols,
+    run_name_from_hf_dataset,
+)
 from dataquality.utils.patcher import PatchManager
 from dataquality.utils.setfit import (
     SetFitModelHook,
-    _get_meta_cols,
     _prepare_config,
     _setup_patches,
     get_trainer,
@@ -346,7 +349,7 @@ def do_model_eval(
     for split in [Split.train, Split.test, Split.val]:
         if split in encoded_data:
             ds = encoded_data[split]
-            meta_columns = _get_meta_cols(ds.column_names)
+            meta_columns = get_meta_cols(ds.column_names)
             dq_evaluate(
                 encoded_data[split],
                 split=split,
@@ -358,7 +361,7 @@
     inf_names = [k for k in encoded_data if k not in Split.get_valid_keys()]
     for inf_name in inf_names:
         ds = encoded_data[inf_name]
-        meta_columns = _get_meta_cols(ds.column_names)
+        meta_columns = get_meta_cols(ds.column_names)
         dq_evaluate(
             ds,
             split=Split.inference,  # type: ignore
1 change: 1 addition & 0 deletions dataquality/loggers/data_logger/seq2seq.py
@@ -132,6 +132,7 @@ def _get_input_df(self) -> DataFrame:
                 C.split_.value: [self.split] * len(self.ids),
                 C.token_label_positions.value: pa.array(self.token_label_positions),
                 C.token_label_offsets.value: pa.array(self.token_label_offsets),
+                **self.meta,
             }
         )
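
The **self.meta spread is plain dict unpacking: every logged meta column becomes its own column in the input DataFrame, alongside the reserved seq2seq columns. A toy illustration with made-up values:

# self.meta maps column name -> list of per-sample values
meta = {"topic": ["math", "history"], "length": [7, 42]}
payload = {
    "id": [0, 1],
    "split": ["training", "training"],
    **meta,  # each meta key becomes an extra column
}
print(list(payload))  # ['id', 'split', 'topic', 'length']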
13 changes: 12 additions & 1 deletion dataquality/utils/auto.py
@@ -3,7 +3,7 @@
 import re
 import warnings
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union

 import pandas as pd
 from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
@@ -14,6 +14,17 @@
 from dataquality.utils.name import BAD_CHARS_REGEX


+def get_meta_cols(
+    cols: Iterable, reserved_cols: Optional[Set[str]] = None
+) -> List[str]:
+    """Returns the meta columns of a dataset."""
+    reserved_cols = reserved_cols or set()
+    default_cols = {"text", "label", "id"}
+    default_cols = set(reserved_cols).union(default_cols)
+    meta_columns = [col for col in cols if col not in default_cols]
+    return list(meta_columns)
+
+
 def load_data_from_str(data: str) -> Union[pd.DataFrame, Dataset]:
     """Loads string data from either hf or disk.
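
The new helper in isolation (standalone calls with hypothetical column names; reserved_cols simply extends the default {"text", "label", "id"} exclusion set, and input order is preserved):

from dataquality.utils.auto import get_meta_cols

print(get_meta_cols(["id", "text", "label", "source"]))
# ['source']
print(get_meta_cols(["id", "question", "answer", "topic"], {"question", "answer"}))
# ['topic']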
12 changes: 3 additions & 9 deletions dataquality/utils/setfit.py
@@ -1,6 +1,6 @@
 import uuid
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -10,6 +10,7 @@
 import dataquality as dq
 from dataquality.schemas.split import Split
 from dataquality.schemas.task_type import TaskType
+from dataquality.utils.auto import get_meta_cols
 from dataquality.utils.patcher import Patch, PatchManager

 BATCH_LOG_SIZE = 10_000
@@ -35,13 +36,6 @@ def clear(self) -> None:
         self.meta.clear()


-def _get_meta_cols(cols: Iterable) -> List[str]:
-    """Returns the meta columns of a dataset."""
-    default_cols = ["text", "label", "id"]
-    meta_columns = [col for col in cols if col not in default_cols]
-    return meta_columns
-
-
 def log_preds_setfit(
     model: "SetFitModel",
     dataset: Dataset,
@@ -79,7 +73,7 @@ def log_preds_setfit(
     skip_logging = logger_config.helper_data[f"setfit_skip_input_log_{split}"]
     # Iterate over the dataset in batches and log the data samples
     # and model outputs
-    meta = _get_meta_cols(dataset.column_names)
+    meta = get_meta_cols(dataset.column_names)
     for i in range(0, len(dataset), batch_size):
         batch = dataset[i : i + batch_size]
         assert text_col in batch, f"column '{text_col}' must be in batch"
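
A detail the batching loop relies on: slicing a Hugging Face Dataset with a range returns a plain dict of column name to list, which is why the `text_col in batch` assertion works and whole columns can be read per batch. For example:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 0]})
batch = ds[0:2]
print(batch)            # {'text': ['a', 'b'], 'label': [0, 1]}
print("text" in batch)  # True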
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -86,7 +86,7 @@ test = [
     "setfit",
     "accelerate>=0.19.0",
     "typing-inspect==0.8.0",
-    "typing-extensions==4.0.0",
+    "typing-extensions==4.0.1",
     "lightning",
 ]
 dev = [
