feat: s2s auto chat support (#779)
Elliott authored Oct 24, 2023
1 parent 6c77a8c commit e40e48f
Showing 9 changed files with 224 additions and 54 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
"""


__version__ = "1.1.3"
__version__ = "1.1.4"

import sys
from typing import Any, List, Optional
6 changes: 4 additions & 2 deletions dataquality/dq_auto/schema.py
@@ -4,7 +4,10 @@
import pandas as pd
from datasets import Dataset, DatasetDict

-from dataquality.integrations.seq2seq.formatter import BaseFormatter, DefaultFormatter
+from dataquality.integrations.seq2seq.formatters.base import (
+    BaseFormatter,
+    DefaultFormatter,
+)


@dataclass
@@ -49,7 +52,6 @@ class BaseAutoDatasetConfig:
    input_col: str = "text"
    target_col: str = "label"
    # Dataset input / output formatter
-    max_train_size: Optional[int] = None
    formatter: BaseFormatter = field(default_factory=DefaultFormatter)

    def __post_init__(self) -> None:
32 changes: 26 additions & 6 deletions dataquality/integrations/seq2seq/auto.py
@@ -6,7 +6,7 @@

import dataquality as dq
from dataquality.dq_auto.base_data_manager import BaseDatasetManager
-from dataquality.integrations.seq2seq.formatter import get_formatter
+from dataquality.integrations.seq2seq.formatters import get_formatter
from dataquality.integrations.seq2seq.s2s_trainer import do_train, get_trainer
from dataquality.integrations.seq2seq.schema import (
    Seq2SeqDatasetConfig,
@@ -74,6 +74,7 @@ def try_load_dataset_dict_from_config(
    def get_dataset_dict_from_config(
        self,
        dataset_config: Optional[Seq2SeqDatasetConfig],
+        max_train_size: Optional[int] = None,
    ) -> Tuple[DatasetDict, Seq2SeqDatasetConfig]:
        """Creates and/or validates the DatasetDict provided by the user.
@@ -110,10 +111,25 @@ def get_dataset_dict_from_config(
        if test_data is not None:
            dd[Split.test] = self._convert_to_hf_dataset(test_data)

-        # Apply the datasets custom formatter on load dataset dict
-        dd = dd.map(dataset_config.formatter.format_sample)
-        dd = sample_dataset_dict(dd, dataset_config)
-        return self._validate_dataset_dict(dd, []), dataset_config
+        # Minimize dataset if user provided a max_train_size
+        dd = sample_dataset_dict(dd, dataset_config, max_train_size)
+        # Add validation data if missing, add 'id' column
+        dd = self._validate_dataset_dict(dd, [])
+        formatter = dataset_config.formatter
+        if formatter.process_batch:
+            # Apply the dataset's custom formatter on dataset dict
+            dd = dd.map(
+                formatter.format_batch,
+                batched=True,
+                remove_columns=dd[Split.train].column_names,
+                with_indices=True,
+            )
+            # We must re-add the id column if it's been dropped
+            dd = self._validate_dataset_dict(dd, [])
+        else:
+            dd = dd.map(formatter.format_sample, remove_columns=formatter.remove_cols)
+
+        return dd, dataset_config

    def _validate_dataset_dict(
        self,
@@ -161,6 +177,7 @@ def auto(
    dataset_config: Optional[Seq2SeqDatasetConfig] = None,
    training_config: Optional[Seq2SeqTrainingConfig] = None,
    generation_config: Optional[Seq2SeqGenerationConfig] = None,
+    max_train_size: Optional[int] = None,
    wait: bool = True,
) -> Optional[PreTrainedModel]:
    """Automatically get insights on a Seq2Seq dataset
@@ -191,6 +208,7 @@
        See `Seq2SeqTrainingConfig` for more details
    :param generation_config: Optional config for generating predictions.
        See `Seq2SeqGenerationConfig` for more details
+    :param max_train_size: Optional max number of training examples to use.
    :param wait: Whether to wait for Galileo to complete processing your run.
        Default True
@@ -230,7 +248,9 @@
    generation_config = generation_config or Seq2SeqGenerationConfig()

    manager = S2SDatasetManager()
-    dd, dataset_config = manager.get_dataset_dict_from_config(dataset_config)
+    dd, dataset_config = manager.get_dataset_dict_from_config(
+        dataset_config, max_train_size
+    )

    if not run_name and isinstance(dataset_config.hf_data, str):
        run_name = run_name_from_hf_dataset(dataset_config.hf_data)
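For callers, the net effect: `max_train_size` moved off the dataset config and became an `auto()` argument, applied when the manager samples splits, before any formatter runs. A minimal sketch, not the canonical invocation — the dataset name is illustrative, and `hf_data` as a config field is inferred from its use in the run-naming code above:

```python
from dataquality.integrations.seq2seq.auto import auto
from dataquality.integrations.seq2seq.schema import Seq2SeqDatasetConfig

# Assumed: Seq2SeqDatasetConfig exposes an hf_data field, as suggested by
# `dataset_config.hf_data` in the hunk above; dataset name is illustrative.
config = Seq2SeqDatasetConfig(hf_data="tatsu-lab/alpaca")

# max_train_size is now threaded from auto() into the dataset manager,
# which caps the train split before formatting.
auto(dataset_config=config, max_train_size=1000)
```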
19 changes: 19 additions & 0 deletions dataquality/integrations/seq2seq/formatters/__init__.py
@@ -0,0 +1,19 @@
from typing import Dict, Type

from dataquality.integrations.seq2seq.formatters.alpaca import AlpacaFormatter
from dataquality.integrations.seq2seq.formatters.base import (
    BaseFormatter,
    DefaultFormatter,
)

FORMATTER_MAPPING: Dict[str, Type[BaseFormatter]] = {
    AlpacaFormatter.name: AlpacaFormatter,
}


def get_formatter(name: str) -> BaseFormatter:
"""Returns the formatter for the given name
If the name isn't found, returns the base formatter
"""
return FORMATTER_MAPPING.get(name, DefaultFormatter)() # type: ignore
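A short sketch of how the new registry resolves names. That `AlpacaFormatter.name` is `"alpaca"` is an assumption here (the field is hidden in a collapsed hunk below); unknown names fall back to the identity `DefaultFormatter`:

```python
from dataquality.integrations.seq2seq.formatters import get_formatter

alpaca = get_formatter("alpaca")     # assumes AlpacaFormatter.name == "alpaca"
fallback = get_formatter("unknown")  # unregistered names -> DefaultFormatter

print(type(alpaca).__name__)    # AlpacaFormatter
print(type(fallback).__name__)  # DefaultFormatter
print(fallback.process_batch)   # False -> auto() maps format_sample row by row
```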
dataquality/integrations/seq2seq/{formatter.py → formatters/alpaca.py}
@@ -1,30 +1,7 @@
-from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Dict, Optional, Type
+from typing import Dict, List, Optional


-@dataclass
-class BaseFormatter(ABC):
-    name: str
-    input_col: str
-    target_col: str
-    max_train_size: Optional[int] = None
-
-    @abstractmethod
-    def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
-        """Base formatter is identity function"""
-        pass
-
-
-@dataclass
-class DefaultFormatter(BaseFormatter):
-    name: str = "default"
-    input_col: str = "text"
-    target_col: str = "label"
-
-    def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
-        """Base formatter is identity function"""
-        return sample
+from dataquality.integrations.seq2seq.formatters.base import BaseFormatter


@dataclass
@@ -34,7 +11,13 @@ class AlpacaFormatter(BaseFormatter):
target_col: str = "output"
max_train_size: int = 1000

-    def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
+    @property
+    def remove_cols(self) -> List[str]:
+        return ["input", "text"]
+
+    def format_sample(
+        self, sample: Dict[str, str], idx: Optional[int] = None
+    ) -> Dict[str, str]:
"""Formats the alpaca dataset for seq2seq
Example:
@@ -43,7 +26,7 @@ def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
... "input": "The quick brown fox jumped over the lazy dog.",
... "target": "The quick brown fox jumped over the lazy dog.",
... }
>>> format_alpaca(sample)
>>> AlpacaFormatter().format_sample(sample)
            {
                "formatted_input": (
                    "Human: Summarize the following paragraph "
@@ -57,16 +40,3 @@ def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
        return {
            "formatted_input": f"{instruction} {context}",
        }


-FORMATTER_MAPPING: Dict[str, Type[BaseFormatter]] = {
-    AlpacaFormatter.name: AlpacaFormatter,
-}
-
-
-def get_formatter(name: str) -> BaseFormatter:
-    """Returns the formatter for the given name
-    If the name isn't found, returns the base formatter
-    """
-    return FORMATTER_MAPPING.get(name, DefaultFormatter)()  # type: ignore
61 changes: 61 additions & 0 deletions dataquality/integrations/seq2seq/formatters/base.py
@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class BatchData:
    batch: Dict[str, Any]

    def sample_from_idx(self, idx: int) -> Dict[str, Any]:
        """Gets a subset of the batch"""
        sample = {}
        for k, v in self.batch.items():
            sample[k] = v[idx]
        return sample


@dataclass
class BaseFormatter(ABC):
    name: str
    input_col: str
    target_col: str
    max_train_size: Optional[int] = None
    process_batch: bool = False

    @property
    def remove_cols(self) -> List[str]:
        return []

    def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]:
        """Formats a batch of chat data for seq2seq"""
        result: Dict[str, List] = defaultdict(list)
        batch_data = BatchData(batch)
        for idx in idxs:
            formatted_sample = self.format_sample(batch_data.sample_from_idx(idx), idx)
            # formatted_sample returns one or more samples per idx, we add to result
            for k, v in formatted_sample.items():
                result[k] += v

        return result

    @abstractmethod
    def format_sample(
        self, sample: Dict[str, Any], idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """Base formatter is identity function"""
        pass


@dataclass
class DefaultFormatter(BaseFormatter):
    name: str = "default"
    input_col: str = "text"
    target_col: str = "label"

    def format_sample(
        self, sample: Dict[str, Any], idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """Base formatter is identity function"""
        return sample
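Two behaviors in this new base module are worth seeing concretely: `datasets.map(batched=True)` hands formatters column-oriented dicts, which `BatchData` converts back into row samples; and a batch-mode `format_sample` must return list-valued columns, since `format_batch` extends `result[k]` with each value. A minimal sketch of the first, grounded in the code above:

```python
from dataquality.integrations.seq2seq.formatters.base import BatchData

# HF map(batched=True) yields column-oriented batches; BatchData recovers
# one row-oriented sample at a given index.
batch = BatchData(batch={"text": ["hi", "bye"], "label": ["greeting", "farewell"]})
assert batch.sample_from_idx(0) == {"text": "hi", "label": "greeting"}
assert batch.sample_from_idx(1) == {"text": "bye", "label": "farewell"}
```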
98 changes: 98 additions & 0 deletions dataquality/integrations/seq2seq/formatters/chat.py
@@ -0,0 +1,98 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from dataquality.integrations.seq2seq.formatters.base import BaseFormatter


@dataclass
class ChatFormatter(BaseFormatter):
    name: str = "chat"
    input_col: str = "input"
    target_col: str = "target"
    max_train_size: Optional[int] = None
    process_batch: bool = True
    # Sample level chat cols
    turns_col: str = "turns"
    metadata_col: str = "metadata"
    # Turn level chat cols
    content_col: str = "content"
    role_col: str = "role"
    # Chat roles
    user: str = "User"
    assistant: str = "Chatbot"

    def format_sample(
        self, sample: Dict[str, Any], idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """Formats a chat dataset for seq2seq

        Takes in a sample with "turns" column and explodes it to have one row
        per turn.

        Example:
            >>> sample = {
            ...     "turns": [
            ...         {"role": "User", "content": "Hello"},
            ...         {"role": "Chatbot", "content": "Hi"},
            ...         {"role": "User", "content": "How are you?"},
            ...         {"role": "Chatbot", "content": "I'm good, how are you?"},
            ...     ],
            ...     "metadata": {"unique_id": 1234, "dataset": "test"},
            ...     "score": 0.5,
            ... }
            >>> ChatFormatter().format_sample(sample, 5)
            {
                "chat_id": [5, 5],
                "turn_id": [1, 2],
                "input": ["Hello", "How are you?"],
                "target": ["Hi", "I'm good, how are you?"],
                "unique_id": [1234, 1234],
                "dataset": ["test", "test"],
            }
        """
        unraveled_turns: Dict[str, Any] = defaultdict(list)
        valid_meta_types = (str, int, float, bool)
        turns: List[Dict[str, Any]] = sample[self.turns_col]

        # Add metadata and sample level cols to each turn
        metadata: Dict[str, Any] = sample.get(self.metadata_col, {})
        for k, v in sample.items():
            if k not in [self.metadata_col, self.turns_col, "id"]:
                metadata[k] = v

        turn_data: Dict[str, Any] = {}
        turn_id = 1
        turn_default_cols = [self.role_col, self.content_col]
        for turn in turns:
            role = turn[self.role_col]
            content = turn[self.content_col]
            # Add metadata to each turn
            turn_meta = {
                f"{role}_{col}": turn[col]
                for col in turn.keys()
                if col not in turn_default_cols
                and isinstance(turn[col], valid_meta_types)
            }
            # Add turn level metadata to turn
            # NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta'
            turn_data.update(turn_meta)

            if role == self.user:
                turn_data[self.input_col] = content
            elif role == self.assistant:
                turn_data[self.target_col] = content
                turn_data["turn_id"] = turn_id
                turn_data["chat_id"] = idx
                # Add sample level metadata
                # NOTE: When we drop p3.8 we can use 'turn_data |= metadata'
                turn_data.update(metadata)
                for k, v in turn_data.items():
                    unraveled_turns[k].append(v)
                # Reset turn data
                turn_data = {}
                turn_id += 1
            else:
                raise ValueError(f"Role {role} not recognized")

        return unraveled_turns
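As a sanity check on the exploding logic, a doctest-style sketch with a two-turn sample (values mirror the docstring above; one output row is produced per User/Chatbot pair, keyed column-wise so it composes with `dd.map(batched=True)`):

```python
from dataquality.integrations.seq2seq.formatters.chat import ChatFormatter

sample = {
    "turns": [
        {"role": "User", "content": "Hello"},
        {"role": "Chatbot", "content": "Hi"},
    ],
    "metadata": {"unique_id": 1234},
}
out = ChatFormatter().format_sample(sample, idx=7)
# Each User/Chatbot pair becomes one row; sample-level metadata is repeated.
assert dict(out) == {
    "input": ["Hello"],
    "target": ["Hi"],
    "turn_id": [1],
    "chat_id": [7],
    "unique_id": [1234],
}
```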
2 changes: 1 addition & 1 deletion dataquality/loggers/data_logger/base_data_logger.py
@@ -612,7 +612,7 @@ def validate_metadata(self, batch_size: int) -> None:
                )
                continue
            # Values must be a point, not an iterable
-            valid_types = (str, int, float, np.floating, np.integer)
+            valid_types = (str, int, float, bool, np.floating, np.integer)
            invalid_values = filter(lambda t: not isinstance(t, valid_types), values)
            bad_val = next(invalid_values, None)
            if bad_val:
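The one-line change above widens the set of scalar ("point") metadata types to include `bool` — needed because the chat formatter copies boolean turn fields into metadata. A small sketch of the check, reconstructed from this hunk:

```python
import numpy as np

# The validator's point-type check after this change: bool now counts as a
# loggable scalar metadata value alongside str/int/float and numpy scalars.
valid_types = (str, int, float, bool, np.floating, np.integer)
values = ["chat", 3, 0.5, True, np.float32(1.0)]
bad_val = next(filter(lambda t: not isinstance(t, valid_types), values), None)
assert bad_val is None  # every value passes validation
```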
8 changes: 4 additions & 4 deletions dataquality/utils/auto.py
@@ -16,7 +16,9 @@


def sample_dataset_dict(
-    dd: DatasetDict, dataset_config: BaseAutoDatasetConfig
+    dd: DatasetDict,
+    dataset_config: BaseAutoDatasetConfig,
+    max_train_size: Optional[int] = None,
) -> DatasetDict:
    """Samples the dataset dict to the max train size
@@ -26,9 +28,7 @@ def sample_dataset_dict(
    - We set max eval size to be 25% of max train size
    - Test and inference data are not sampled
    """
-    max_train_sz = (
-        dataset_config.max_train_size or dataset_config.formatter.max_train_size
-    )
+    max_train_sz = max_train_size or dataset_config.formatter.max_train_size
    if not max_train_sz:
        return dd

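The sampling precedence is now: an explicit `max_train_size` argument wins, otherwise the formatter's own default applies (e.g. `AlpacaFormatter` caps at 1000). A sketch with a minimal stand-in config — `FakeConfig` is hypothetical and stubs only the `formatter` field this function reads in the hunk above:

```python
from datasets import Dataset, DatasetDict

from dataquality.integrations.seq2seq.formatters import get_formatter
from dataquality.utils.auto import sample_dataset_dict


class FakeConfig:
    """Hypothetical stand-in for BaseAutoDatasetConfig; formatter field only."""
    formatter = get_formatter("alpaca")  # AlpacaFormatter: max_train_size = 1000


dd = DatasetDict({"train": Dataset.from_dict({"text": ["x"] * 5000})})

# No explicit cap -> the formatter's default applies
print(len(sample_dataset_dict(dd, FakeConfig())["train"]))                      # expect 1000
# An explicit max_train_size overrides the formatter default
print(len(sample_dataset_dict(dd, FakeConfig(), max_train_size=100)["train"]))  # expect 100
```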
