Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tokenizer's default behavior on unknown token #375

Merged
merged 4 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,12 @@ def add_single_tokenizer(
# we update the special tokens but do not save here. remember to save yourself.
self.update_special_tokens(
special_tokens=new_tokenize_special_tokens,
# save_tokenizer_path=self.cfg_raw["data"]["tokenizer"]["out_path"],
)

def add_tokenizers(
self,
) -> None:
raise Exception("Not implemented")
# self.build_inner_decoder()
# if self._max_possible_token_id is not None:
# if self._get_max_mapped_id() > self._max_possible_token_id:
# raise Exception(
# f"tokenizer remapping resulted in IDs greater (max_id={self._get_max_mapped_id()}) than max_possible_id ({self._max_possible_token_id}). Reinitialize the modular tokenizer with larger max_possible_id"
# )

def _encode_single_type(
self,
Expand Down Expand Up @@ -1059,10 +1052,6 @@ def encode_list(
merged_encoding = Encoding.merge(encoded_list)

max_len = self.get_expected_max_len(override_max_len=max_len)
# if max_len is None:
# if self.max_len is not None:
# max_len = self.max_len

if max_len is not None:
if len(merged_encoding) > max_len:
overflow_info += f"OVERALL:{len(merged_encoding)}=>{max_len}|"
Expand Down
21 changes: 17 additions & 4 deletions fuse/data/tokenizers/modular_tokenizer/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
validate_ends_with_eos: Optional[bool] = True,
eos: Optional[str] = "<EOS>",
verbose: Optional[bool] = False,
on_unknown_default_value: str = "warn",
**kwargs: Any,
) -> None:
"""
Expand All @@ -41,6 +42,7 @@ def __init__(
validate_ends_with_eos: during encoder request (a _call_ to the op) will make sure that it ends with the provided eos token, and raise exception otherwise.
having an eos (end of sentence) token in the end is useful for multiple scenarios, for example in a generative transformer (like T5 encoder-decoder)
verbose:
on_unknown_default_value: Default behavior when unknown tokens are encountered: 'warn' or 'raise'. This instance-level default can be overridden per call via the on_unknown argument of __call__.
"""
super().__init__(**kwargs)

Expand All @@ -60,6 +62,10 @@ def __init__(

self._validate_ends_with_eos = validate_ends_with_eos
self._eos = eos
self._on_unknown_default_value = on_unknown_default_value

if on_unknown_default_value not in ["warn", "raise"]:
raise ValueError(f"Doesn't support {on_unknown_default_value=}!")

if self._validate_ends_with_eos:
eos_id = self._tokenizer.token_to_id(self._eos)
Expand Down Expand Up @@ -211,7 +217,7 @@ def __call__(
key_out_attention_mask: Optional[str] = None,
convert_attention_mask_to_bool: Optional[bool] = True,
max_seq_len: Optional[int] = None,
on_unknown: Optional[str] = "warn",
on_unknown: Optional[str] = None,
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
additional_caller_info_text: Optional[str] = "",
Expand All @@ -230,7 +236,7 @@ def __call__(
key_out_attention_mask (Optional[str], optional): _description_. Defaults to None.
convert_attention_mask_to_bool (Optional[bool], optional): _description_. Defaults to True.
max_seq_len (Optional[int], optional): set maximum sequence len dynamically, used for both padding and truncation. Defaults to None.
on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to "warn".
on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to None, in which case the on_unknown_default_value given to the constructor is used ("warn" unless set otherwise).
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
Expand All @@ -243,7 +249,6 @@ def __call__(
Returns:
NDict: _description_
"""

data = sample_dict[key_in]
if not isinstance(data, (list, str)):
# data is a list of named tuples of type collections.namedtuple("TypedInput", ["input_type", "input_string", "max_len"])
Expand All @@ -263,6 +268,10 @@ def __call__(
f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str end was: {last_seq}"
)

if on_unknown is None:
# Use tokenizer instance default value
on_unknown = self._on_unknown_default_value

if isinstance(data, str):
_ans = self._tokenizer.encode(
data,
Expand Down Expand Up @@ -510,6 +519,7 @@ def from_pretrained(
identifier: str,
pad_token: str = "<PAD>",
max_size: Optional[int] = None,
on_unknown_default_value: str = "warn",
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict] = None,
Expand Down Expand Up @@ -549,7 +559,10 @@ def from_pretrained(
) from e

tokenizer_op = cls(
tokenizer_path=identifier, pad_token=pad_token, max_size=max_size
tokenizer_path=identifier,
pad_token=pad_token,
max_size=max_size,
on_unknown_default_value=on_unknown_default_value,
)
return tokenizer_op

Expand Down
Loading