diff --git a/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py b/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py index 2d57c107b..304925e21 100644 --- a/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py +++ b/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py @@ -868,19 +868,12 @@ def add_single_tokenizer( # we update the special tokens but do not save here. remember to save yourself. self.update_special_tokens( special_tokens=new_tokenize_special_tokens, - # save_tokenizer_path=self.cfg_raw["data"]["tokenizer"]["out_path"], ) def add_tokenizers( self, ) -> None: raise Exception("Not implemented") - # self.build_inner_decoder() - # if self._max_possible_token_id is not None: - # if self._get_max_mapped_id() > self._max_possible_token_id: - # raise Exception( - # f"tokenizer remapping resulted in IDs greater (max_id={self._get_max_mapped_id()}) than max_possible_id ({self._max_possible_token_id}). Reinitialize the modular tokenizer with larger max_possible_id" - # ) def _encode_single_type( self, @@ -1059,10 +1052,6 @@ def encode_list( merged_encoding = Encoding.merge(encoded_list) max_len = self.get_expected_max_len(override_max_len=max_len) - # if max_len is None: - # if self.max_len is not None: - # max_len = self.max_len - if max_len is not None: if len(merged_encoding) > max_len: overflow_info += f"OVERALL:{len(merged_encoding)}=>{max_len}|" diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py index 261dccebd..67d5ff4d5 100644 --- a/fuse/data/tokenizers/modular_tokenizer/op.py +++ b/fuse/data/tokenizers/modular_tokenizer/op.py @@ -30,6 +30,7 @@ def __init__( validate_ends_with_eos: Optional[bool] = True, eos: Optional[str] = "", verbose: Optional[bool] = False, + on_unknown_default_value: str = "warn", **kwargs: Any, ) -> None: """ @@ -41,6 +42,7 @@ def __init__( validate_ends_with_eos: during encoder request (a _call_ to the op) will make sure that it ends with the provided eos token, and raise exception otherwise. having an eos (end of sentence) token in the end is useful for multiple scenarios, for example in a generative transformer (like T5 encoder-decoder) verbose: + on_unknown_default_value: User can define the default behavior of unknown token here in the constructor. In addition, this value can be overwritten in the __call__ """ super().__init__(**kwargs) @@ -60,6 +62,10 @@ def __init__( self._validate_ends_with_eos = validate_ends_with_eos self._eos = eos + self._on_unknown_default_value = on_unknown_default_value + + if on_unknown_default_value not in ["warn", "raise"]: + raise ValueError(f"Doesn't support {on_unknown_default_value=}!") if self._validate_ends_with_eos: eos_id = self._tokenizer.token_to_id(self._eos) @@ -211,7 +217,7 @@ def __call__( key_out_attention_mask: Optional[str] = None, convert_attention_mask_to_bool: Optional[bool] = True, max_seq_len: Optional[int] = None, - on_unknown: Optional[str] = "warn", + on_unknown: Optional[str] = None, verbose: Optional[int] = 1, validate_ends_with_eos: Optional[bool] = None, additional_caller_info_text: Optional[str] = "", @@ -230,7 +236,7 @@ def __call__( key_out_attention_mask (Optional[str], optional): _description_. Defaults to None. convert_attention_mask_to_bool (Optional[bool], optional): _description_. Defaults to True. max_seq_len (Optional[int], optional): set maximum sequence len dynamically, used for both padding and truncation.. Defaults to None. - on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to ) are encountered: 'raise' or 'warn'. Defaults to "warn". + on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to ) are encountered: 'raise' or 'warn'. Defaults to "warn". The default value can be determined in the constructor itself. verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning with full data. Defaults to 1. validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos @@ -243,7 +249,6 @@ def __call__( Returns: NDict: _description_ """ - data = sample_dict[key_in] if not isinstance(data, (list, str)): # data is a list of named tuples of type collections.namedtuple("TypedInput", ["input_type", "input_string", "max_len"]) @@ -263,6 +268,10 @@ def __call__( f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str end was: {last_seq}" ) + if on_unknown is None: + # Use tokenizer instance default value + on_unknown = self._on_unknown_default_value + if isinstance(data, str): _ans = self._tokenizer.encode( data, @@ -510,6 +519,7 @@ def from_pretrained( identifier: str, pad_token: str = "", max_size: Optional[int] = None, + on_unknown_default_value: str = "warn", force_download: bool = False, resume_download: Optional[bool] = None, proxies: Optional[Dict] = None, @@ -549,7 +559,10 @@ def from_pretrained( ) from e tokenizer_op = cls( - tokenizer_path=identifier, pad_token=pad_token, max_size=max_size + tokenizer_path=identifier, + pad_token=pad_token, + max_size=max_size, + on_unknown_default_value=on_unknown_default_value, ) return tokenizer_op