-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Elliott
authored
Oct 24, 2023
1 parent
6c77a8c
commit e40e48f
Showing
9 changed files
with
224 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
""" | ||
|
||
|
||
__version__ = "1.1.3" | ||
__version__ = "1.1.4" | ||
|
||
import sys | ||
from typing import Any, List, Optional | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Dict, Type | ||
|
||
from dataquality.integrations.seq2seq.formatters.alpaca import AlpacaFormatter | ||
from dataquality.integrations.seq2seq.formatters.base import ( | ||
BaseFormatter, | ||
DefaultFormatter, | ||
) | ||
|
||
FORMATTER_MAPPING: Dict[str, Type[BaseFormatter]] = { | ||
AlpacaFormatter.name: AlpacaFormatter, | ||
} | ||
|
||
|
||
def get_formatter(name: str) -> BaseFormatter: | ||
"""Returns the formatter for the given name | ||
If the name isn't found, returns the base formatter | ||
""" | ||
return FORMATTER_MAPPING.get(name, DefaultFormatter)() # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from abc import ABC, abstractmethod | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
||
@dataclass | ||
class BatchData: | ||
batch: Dict[str, Any] | ||
|
||
def sample_from_idx(self, idx: int) -> Dict[str, Any]: | ||
"""Gets a subset of the batch""" | ||
sample = {} | ||
for k, v in self.batch.items(): | ||
sample[k] = v[idx] | ||
return sample | ||
|
||
|
||
@dataclass | ||
class BaseFormatter(ABC): | ||
name: str | ||
input_col: str | ||
target_col: str | ||
max_train_size: Optional[int] = None | ||
process_batch: bool = False | ||
|
||
@property | ||
def remove_cols(self) -> List[str]: | ||
return [] | ||
|
||
def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]: | ||
"""Formats a batch of chat data for seq2seq""" | ||
result: Dict[str, List] = defaultdict(list) | ||
batch_data = BatchData(batch) | ||
for idx in idxs: | ||
formatted_sample = self.format_sample(batch_data.sample_from_idx(idx), idx) | ||
# formatted_sample returns one or more samples per idx, we add to result | ||
for k, v in formatted_sample.items(): | ||
result[k] += v | ||
|
||
return result | ||
|
||
@abstractmethod | ||
def format_sample( | ||
self, sample: Dict[str, Any], idx: Optional[int] = None | ||
) -> Dict[str, Any]: | ||
"""Base formatter is identity function""" | ||
pass | ||
|
||
|
||
@dataclass | ||
class DefaultFormatter(BaseFormatter): | ||
name: str = "default" | ||
input_col: str = "text" | ||
target_col: str = "label" | ||
|
||
def format_sample( | ||
self, sample: Dict[str, Any], idx: Optional[int] = None | ||
) -> Dict[str, Any]: | ||
"""Base formatter is identity function""" | ||
return sample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Optional | ||
|
||
from dataquality.integrations.seq2seq.formatters.base import BaseFormatter | ||
|
||
|
||
@dataclass | ||
class ChatFormatter(BaseFormatter): | ||
name: str = "chat" | ||
input_col: str = "input" | ||
target_col: str = "target" | ||
max_train_size: Optional[int] = None | ||
process_batch: bool = True | ||
# Sample level chat cols | ||
turns_col: str = "turns" | ||
metadata_col: str = "metadata" | ||
# Turn level chat cols | ||
content_col: str = "content" | ||
role_col: str = "role" | ||
# Chat roles | ||
user: str = "User" | ||
assistant: str = "Chatbot" | ||
|
||
def format_sample( | ||
self, sample: Dict[str, Any], idx: Optional[int] = None | ||
) -> Dict[str, Any]: | ||
"""Formats a chat dataset for seq2seq | ||
Takes in a sample with "turns" column and explodes it to have one row | ||
per turn. | ||
Example: | ||
>>> sample = { | ||
... "turns": [ | ||
... {"role": "User", "content": "Hello"}, | ||
... {"role": "Chatbot", "content": "Hi"}, | ||
... {"role": "User", "content": "How are you?"}, | ||
... {"role": "Chatbot", "content": "I'm good, how are you?"}, | ||
... ], | ||
... "metadata": {"unique_id": 1234, "dataset": "test"}, | ||
... "score": 0.5, | ||
... } | ||
>>> ChatFormatter().format_sample(sample, 5) | ||
{ | ||
"chat_id": [5, 5], | ||
"turn_id": [1, 2], | ||
"input": ["Hello", "How are you?"], | ||
"target": ["Hi", "I'm good, how are you?"], | ||
"unique_id": [1234, 1234], | ||
"dataset": ["test", "test"], | ||
} | ||
""" | ||
unraveled_turns: Dict[str, Any] = defaultdict(list) | ||
valid_meta_types = (str, int, float, bool) | ||
turns: List[Dict[str, Any]] = sample[self.turns_col] | ||
|
||
# # Add metadata and sample level cols to each turn | ||
metadata: Dict[str, Any] = sample.get(self.metadata_col, {}) | ||
for k, v in sample.items(): | ||
if k not in [self.metadata_col, self.turns_col, "id"]: | ||
metadata[k] = v | ||
|
||
turn_data: Dict[str, Any] = {} | ||
turn_id = 1 | ||
turn_default_cols = [self.role_col, self.content_col] | ||
for turn in turns: | ||
role = turn[self.role_col] | ||
content = turn[self.content_col] | ||
# Add metadata to each turn | ||
turn_meta = { | ||
f"{role}_{col}": turn[col] | ||
for col in turn.keys() | ||
if col not in turn_default_cols | ||
and isinstance(turn[col], valid_meta_types) | ||
} | ||
# Add turn level metadata to turn | ||
# NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta' | ||
turn_data.update(turn_meta) | ||
|
||
if role == self.user: | ||
turn_data[self.input_col] = content | ||
elif role == self.assistant: | ||
turn_data[self.target_col] = content | ||
turn_data["turn_id"] = turn_id | ||
turn_data["chat_id"] = idx | ||
# Add sample level metadata | ||
# NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta' | ||
turn_data.update(metadata) | ||
for k, v in turn_data.items(): | ||
unraveled_turns[k].append(v) | ||
# Reset turn data | ||
turn_data = {} | ||
turn_id += 1 | ||
else: | ||
raise ValueError(f"Role {role} not recognized") | ||
|
||
return unraveled_turns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters