diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index f1da1f552075..0ef7f2edb689 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -512,7 +512,10 @@ def get_feature_extractor_dict( with open(resolved_feature_extractor_file, encoding="utf-8") as reader: text = reader.read() feature_extractor_dict = json.loads(text) - feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict) + if "audio_processor" in feature_extractor_dict: + feature_extractor_dict = feature_extractor_dict["audio_processor"] + else: + feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict) except json.JSONDecodeError: raise OSError( diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index cb2eb94cecd4..82ea2dcfc632 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -300,7 +300,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs) if processor_config_file is not None: config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs) - processor_class = config_dict.get("processor_class", None) + processor_class = config_dict.get("processor_class") if "AutoProcessor" in config_dict.get("auto_map", {}): processor_auto_map = config_dict["auto_map"]["AutoProcessor"] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 55844c8d9cce..643734872be6 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -524,7 +524,6 @@ class ProcessorMixin(PushToHubMixin): """ attributes = ["feature_extractor", "tokenizer"] - optional_attributes = ["chat_template", 
"audio_tokenizer"] optional_call_args: list[str] = [] # Names need to be attr_class for attr in attributes feature_extractor_class = None @@ -534,21 +533,18 @@ class ProcessorMixin(PushToHubMixin): # args have to match the attributes class attribute def __init__(self, *args, **kwargs): - # First, extract optional attributes from kwargs if present - # Optional attributes can never be positional arguments - for optional_attribute in self.optional_attributes: - optional_attribute_value = kwargs.pop(optional_attribute, None) - setattr(self, optional_attribute, optional_attribute_value) + # First, extract chat template from kwargs. It can never be a positional arg + setattr(self, "chat_template", kwargs.pop("chat_template", None)) - # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights - if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None: - proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value) - - if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)): - raise ValueError( - f"Tried to use `{proper_class}` for audio tokenization. However, this class is not" - " registered for audio tokenization." - ) + # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights + if (audio_tokenizer := kwargs.pop("audio_tokenizer", None)) is not None: + proper_class = self.check_argument_for_proper_class("audio_tokenizer", audio_tokenizer) + if not (is_torch_available() and isinstance(audio_tokenizer, PreTrainedAudioTokenizerBase)): + raise ValueError( + f"Tried to use `{proper_class}` for audio tokenization. However, this class is not" + " registered for audio tokenization." 
+ ) + setattr(self, "audio_tokenizer", audio_tokenizer) # Sanitize args and kwargs for key in kwargs: @@ -652,7 +648,7 @@ def check_argument_for_proper_class(self, argument_name, argument): return proper_class - def to_dict(self, legacy_serialization=True) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes this instance to a Python dictionary. @@ -664,23 +660,33 @@ def to_dict(self, legacy_serialization=True) -> dict[str, Any]: # Get the kwargs in `__init__`. sig = inspect.signature(self.__init__) # Only save the attributes that are presented in the kwargs of `__init__`. - attrs_to_save = list(sig.parameters) + # or in the attributes + attrs_to_save = list(sig.parameters) + self.__class__.attributes # extra attributes to be kept attrs_to_save += ["auto_map"] - if legacy_serialization: - # Don't save attributes like `tokenizer`, `image processor` etc. in processor config if `legacy=True` - attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes] - if "tokenizer" in output: del output["tokenizer"] if "qformer_tokenizer" in output: del output["qformer_tokenizer"] if "protein_tokenizer" in output: del output["protein_tokenizer"] + if "char_tokenizer" in output: + del output["char_tokenizer"] if "chat_template" in output: del output["chat_template"] + def save_public_processor_class(dictionary): + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + for value in dictionary.values(): + if isinstance(value, dict): + save_public_processor_class(value) + return dictionary + def cast_array_to_list(dictionary): """ Numpy arrays are not serialiazable but can be in pre-processing dicts. 
@@ -693,6 +699,14 @@ def cast_array_to_list(dictionary): dictionary[key] = cast_array_to_list(value) return dictionary + # Special case, add `audio_tokenizer` dict which points to model weights and path + if "audio_tokenizer" in output: + audio_tokenizer_dict = { + "audio_tokenizer_class": self.audio_tokenizer.__class__.__name__, + "audio_tokenizer_name_or_path": self.audio_tokenizer.name_or_path, + } + output["audio_tokenizer"] = audio_tokenizer_dict + # Serialize attributes as a dict output = { k: v.to_dict() if isinstance(v, PushToHubMixin) else v @@ -700,38 +714,26 @@ def cast_array_to_list(dictionary): if ( k in attrs_to_save # keep all attributes that have to be serialized and v.__class__.__name__ != "BeamSearchDecoderCTC" # remove attributes with that are objects - and ( - (legacy_serialization and not isinstance(v, PushToHubMixin)) or not legacy_serialization - ) # remove `PushToHubMixin` objects ) } output = cast_array_to_list(output) - - # Special case, add `audio_tokenizer` dict which points to model weights and path - if not legacy_serialization and "audio_tokenizer" in output: - audio_tokenizer_dict = { - "audio_tokenizer_class": self.audio_tokenizer.__class__.__name__, - "audio_tokenizer_name_or_path": self.audio_tokenizer.name_or_path, - } - # Update or overwrite, what do audio tokenizers expect when loading? - output["audio_tokenizer"] = audio_tokenizer_dict - + output = save_public_processor_class(output) output["processor_class"] = self.__class__.__name__ return output - def to_json_string(self, legacy_serialization=True) -> str: + def to_json_string(self) -> str: """ Serializes this instance to a JSON string. Returns: `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. 
""" - dictionary = self.to_dict(legacy_serialization=legacy_serialization) + dictionary = self.to_dict() return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" - def to_json_file(self, json_file_path: Union[str, os.PathLike], legacy_serialization=True): + def to_json_file(self, json_file_path: Union[str, os.PathLike]): """ Save this instance to a JSON file. @@ -740,14 +742,14 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], legacy_serializa Path to the JSON file in which this processor instance's parameters will be saved. """ with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string(legacy_serialization=legacy_serialization)) + writer.write(self.to_json_string()) def __repr__(self): attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes] attributes_repr = "\n".join(attributes_repr) return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}" - def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_serialization: bool = True, **kwargs): + def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): """ Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it can be reloaded using the [`~ProcessorMixin.from_pretrained`] method. @@ -768,10 +770,6 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - legacy_serialization (`bool`, *optional*, defaults to `True`): - Whether or not to save processor attributes in separate config files (legacy) or in processor's config - file as a nested dict. Saving all attributes in a single dict will become the default in future versions. - Set to `legacy_serialization=True` until then. 
kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -806,20 +804,16 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri save_jinja_files = kwargs.get("save_jinja_files", True) for attribute_name in self.attributes: + attribute = getattr(self, attribute_name) + if hasattr(attribute, "_set_processor_class"): + attribute._set_processor_class(self.__class__.__name__) + # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json` if attribute_name == "tokenizer": - attribute = getattr(self, attribute_name) - if hasattr(attribute, "_set_processor_class"): - attribute._set_processor_class(self.__class__.__name__) - # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files) - elif legacy_serialization: - attribute = getattr(self, attribute_name) - # Include the processor class in attribute config so this processor can then be reloaded with `AutoProcessor` API. - if hasattr(attribute, "_set_processor_class"): - attribute._set_processor_class(self.__class__.__name__) - attribute.save_pretrained(save_directory) + elif attribute._auto_class is not None: + custom_object_save(attribute, save_directory, config=attribute) if self._auto_class is not None: # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up. 
@@ -832,9 +826,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri # plus we save chat_template in its own file output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE) - output_chat_template_file_legacy = os.path.join( - save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE - ) # Legacy filename + output_chat_template_file_legacy = os.path.join(save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE) chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR) # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` @@ -877,39 +869,10 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri "separate files using the `save_jinja_files` argument." ) - if legacy_serialization: - output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME) - processor_dict = self.to_dict() - - # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and - # `auto_map` is not specified. 
- if set(processor_dict.keys()) != {"processor_class"}: - self.to_json_file(output_processor_file) - logger.info(f"processor saved in {output_processor_file}") - - if set(processor_dict.keys()) == {"processor_class"}: - return_files = [] - else: - return_files = [output_processor_file] - - if self.audio_tokenizer is not None: - audio_tokenizer_class = self.audio_tokenizer.__class__.__name__ - audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path - audio_tokenizer_dict = { - "audio_tokenizer_class": audio_tokenizer_class, - "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path, - } - audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n" - with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer: - writer.write(audio_tokenizer_json) - # Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers - # NOTE: this will become the default way to save all processor attrbiutes in future versions. 
Toggled off for now to give - # us time for smoother transition - else: - self.to_json_file(output_processor_file, legacy_serialization=False) - logger.info(f"processor saved in {output_processor_file}") - return_files = [output_processor_file] + self.to_json_file(output_processor_file) + logger.info(f"processor saved in {output_processor_file}") + return_files = [output_processor_file] if push_to_hub: self._upload_modified_files( @@ -1169,10 +1132,6 @@ def get_processor_dict( audio_tokenizer_path, **audio_tokenizer_kwargs ) - # Pop attributes if saved in a single processor dict, they are loaded in `_get_arguments_from_pretrained` - for attribute in cls.attributes: - processor_dict.pop(attribute, None) - return processor_dict, kwargs @classmethod @@ -1196,12 +1155,9 @@ def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs): return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs - # If we don't pop, some specific kwargs will raise a warning - if "processor_class" in processor_dict: - del processor_dict["processor_class"] - - if "auto_map" in processor_dict: - del processor_dict["auto_map"] + # If we don't pop, some specific kwargs will raise a warning or error + for unused_kwarg in cls.attributes + ["auto_map", "processor_class"]: + processor_dict.pop(unused_kwarg, None) # override processor_dict with given kwargs processor_dict.update(kwargs) @@ -1461,8 +1417,8 @@ def from_pretrained( if token is not None: kwargs["token"] = token - args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) return cls.from_args_and_dict(args, processor_dict, **kwargs) @classmethod diff --git a/tests/models/align/test_processing_align.py 
b/tests/models/align/test_processing_align.py index 0adfc5a82205..01190c247a22 100644 --- a/tests/models/align/test_processing_align.py +++ b/tests/models/align/test_processing_align.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import shutil import tempfile @@ -23,7 +22,7 @@ from transformers import BertTokenizer, BertTokenizerFast from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -67,9 +66,9 @@ def setUp(self): "image_mean": [0.48145466, 0.4578275, 0.40821073], "image_std": [0.26862954, 0.26130258, 0.27577711], } - self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(self.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) + image_processor = EfficientNetImageProcessor(**image_processor_map) + processor = AlignProcessor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py index 6eabd690eed9..265abbd731c4 100644 --- a/tests/models/auto/test_processor_auto.py +++ b/tests/models/auto/test_processor_auto.py @@ -122,37 +122,6 @@ def test_processor_from_processor_class(self): self.assertIsInstance(processor, Wav2Vec2Processor) - def test_processor_from_feat_extr_processor_class(self): - with tempfile.TemporaryDirectory() as tmpdirname: - feature_extractor = Wav2Vec2FeatureExtractor() - tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h") - - processor = 
Wav2Vec2Processor(feature_extractor, tokenizer) - - # save in new folder - processor.save_pretrained(tmpdirname) - - if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)): - # drop `processor_class` in processor - with open(os.path.join(tmpdirname, PROCESSOR_NAME)) as f: - config_dict = json.load(f) - config_dict.pop("processor_class") - - with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f: - f.write(json.dumps(config_dict)) - - # drop `processor_class` in tokenizer - with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE)) as f: - config_dict = json.load(f) - config_dict.pop("processor_class") - - with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE), "w") as f: - f.write(json.dumps(config_dict)) - - processor = AutoProcessor.from_pretrained(tmpdirname) - - self.assertIsInstance(processor, Wav2Vec2Processor) - def test_processor_from_tokenizer_processor_class(self): with tempfile.TemporaryDirectory() as tmpdirname: feature_extractor = Wav2Vec2FeatureExtractor() @@ -163,21 +132,11 @@ def test_processor_from_tokenizer_processor_class(self): # save in new folder processor.save_pretrained(tmpdirname) - if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)): - # drop `processor_class` in processor - with open(os.path.join(tmpdirname, PROCESSOR_NAME)) as f: - config_dict = json.load(f) - config_dict.pop("processor_class") - - with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f: - f.write(json.dumps(config_dict)) - - # drop `processor_class` in feature extractor - with open(os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME)) as f: + # drop `processor_class` in processor + with open(os.path.join(tmpdirname, PROCESSOR_NAME)) as f: config_dict = json.load(f) config_dict.pop("processor_class") - - with open(os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME), "w") as f: + with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f: f.write(json.dumps(config_dict)) processor = AutoProcessor.from_pretrained(tmpdirname) diff --git 
a/tests/models/chinese_clip/test_processing_chinese_clip.py b/tests/models/chinese_clip/test_processing_chinese_clip.py index 5aef3d06c15b..dab0d37773c9 100644 --- a/tests/models/chinese_clip/test_processing_chinese_clip.py +++ b/tests/models/chinese_clip/test_processing_chinese_clip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import shutil import tempfile @@ -23,7 +22,7 @@ from transformers import BertTokenizer, BertTokenizerFast from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision -from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available +from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -74,12 +73,8 @@ def setUpClass(cls): "image_std": [0.26862954, 0.26130258, 0.27577711], "do_convert_rgb": True, } - cls.image_processor_file = os.path.join(cls.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(cls.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) - tokenizer = cls.get_tokenizer() - image_processor = cls.get_image_processor() + image_processor = ChineseCLIPImageProcessor(**image_processor_map) processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor) processor.save_pretrained(cls.tmpdirname) diff --git a/tests/models/clipseg/test_processing_clipseg.py b/tests/models/clipseg/test_processing_clipseg.py index f7255838caa8..98b59373c429 100644 --- a/tests/models/clipseg/test_processing_clipseg.py +++ b/tests/models/clipseg/test_processing_clipseg.py @@ -23,7 +23,7 @@ from transformers import CLIPTokenizer, CLIPTokenizerFast from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from transformers.utils 
import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -60,9 +60,9 @@ def setUp(self): "image_mean": [0.48145466, 0.4578275, 0.40821073], "image_std": [0.26862954, 0.26130258, 0.27577711], } - self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(self.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) + image_processor = ViTImageProcessor(**image_processor_map) + processor = CLIPSegProcessor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): return CLIPTokenizer.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/flava/test_processing_flava.py b/tests/models/flava/test_processing_flava.py index 10a00a869915..b50da8e244fa 100644 --- a/tests/models/flava/test_processing_flava.py +++ b/tests/models/flava/test_processing_flava.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os import random import shutil @@ -24,7 +23,7 @@ from transformers import BertTokenizer, BertTokenizerFast from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -76,9 +75,9 @@ def setUp(self): "codebook_image_std": FLAVA_CODEBOOK_STD, } - self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(self.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) + image_processor = FlavaImageProcessor(**image_processor_map) + processor = FlavaProcessor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/granite_speech/test_processing_granite_speech.py b/tests/models/granite_speech/test_processing_granite_speech.py index 569ac9cfbc19..ca849ba10103 100644 --- a/tests/models/granite_speech/test_processing_granite_speech.py +++ b/tests/models/granite_speech/test_processing_granite_speech.py @@ -44,10 +44,10 @@ def setUp(self): processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs): - return AutoTokenizer.from_pretrained(self.checkpoint, **kwargs) + return AutoTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_audio_processor(self, **kwargs): - return GraniteSpeechFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + return GraniteSpeechFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) def tearDown(self): shutil.rmtree(self.tmpdirname) diff --git a/tests/models/layoutlmv2/test_processing_layoutlmv2.py b/tests/models/layoutlmv2/test_processing_layoutlmv2.py index dc441b3030c0..9a116e54c9a7 100644 --- 
a/tests/models/layoutlmv2/test_processing_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_processing_layoutlmv2.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import shutil import tempfile @@ -23,7 +22,7 @@ from transformers.models.layoutlmv2 import LayoutLMv2Processor, LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow -from transformers.utils import FEATURE_EXTRACTOR_NAME, is_pytesseract_available +from transformers.utils import is_pytesseract_available from ...test_processing_common import ProcessorTesterMixin @@ -68,9 +67,10 @@ def setUp(self): self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - self.image_processing_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(self.image_processing_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(image_processor_map) + "\n") + + image_processor = LayoutLMv2ImageProcessor(**image_processor_map) + processor = LayoutLMv2Processor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/layoutlmv3/test_processing_layoutlmv3.py b/tests/models/layoutlmv3/test_processing_layoutlmv3.py index cf44966327d3..b7a51a940a5b 100644 --- a/tests/models/layoutlmv3/test_processing_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_processing_layoutlmv3.py @@ -23,7 +23,7 @@ from transformers.models.layoutlmv3 import LayoutLMv3Processor, LayoutLMv3Tokenizer, 
LayoutLMv3TokenizerFast from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow -from transformers.utils import FEATURE_EXTRACTOR_NAME, is_pytesseract_available +from transformers.utils import is_pytesseract_available from ...test_processing_common import ProcessorTesterMixin @@ -81,9 +81,9 @@ def setUp(self): "apply_ocr": True, } - self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(image_processor_map) + "\n") + image_processor = LayoutLMv3ImageProcessor(**image_processor_map) + processor = LayoutLMv3Processor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/markuplm/test_processing_markuplm.py b/tests/models/markuplm/test_processing_markuplm.py index 59f7f365d693..42674ad61fb3 100644 --- a/tests/models/markuplm/test_processing_markuplm.py +++ b/tests/models/markuplm/test_processing_markuplm.py @@ -28,7 +28,7 @@ ) from transformers.models.markuplm.tokenization_markuplm import VOCAB_FILES_NAMES from transformers.testing_utils import require_bs4, require_tokenizers, require_torch, slow -from transformers.utils import FEATURE_EXTRACTOR_NAME, is_bs4_available, is_tokenizers_available +from transformers.utils import is_bs4_available, is_tokenizers_available if is_bs4_available(): @@ -64,10 +64,9 @@ def setUp(self): with open(self.tokenizer_config_file, "w", encoding="utf-8") as fp: fp.write(json.dumps({"tags_dict": self.tags_dict})) - feature_extractor_map = {"feature_extractor_type": "MarkupLMFeatureExtractor"} - self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) - with 
open(self.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(feature_extractor_map) + "\n") + feature_extractor = MarkupLMFeatureExtractor() + processor = MarkupLMProcessor(tokenizer=self.get_tokenizer(), feature_extractor=feature_extractor) + processor.save_pretrained(self.tmpdirname) def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/mgp_str/test_processing_mgp_str.py b/tests/models/mgp_str/test_processing_mgp_str.py index a28e956bc6ec..17336d351211 100644 --- a/tests/models/mgp_str/test_processing_mgp_str.py +++ b/tests/models/mgp_str/test_processing_mgp_str.py @@ -25,7 +25,7 @@ from transformers import MgpstrTokenizer from transformers.models.mgp_str.tokenization_mgp_str import VOCAB_FILES_NAMES from transformers.testing_utils import require_torch, require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_vision_available if is_torch_available(): @@ -65,9 +65,9 @@ def setUp(self): "resample": 3, "size": {"height": 32, "width": 128}, } - self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(self.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) + image_processor = ViTImageProcessor(**image_processor_map) + processor = MgpstrProcessor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) # We copy here rather than use the ProcessorTesterMixin as this processor has a `char_tokenizer` instead of a # tokenizer attribute, which means all the tests would need to be overridden. 
diff --git a/tests/models/oneformer/test_processing_oneformer.py b/tests/models/oneformer/test_processing_oneformer.py index 5e5a26e3a796..fbae54699727 100644 --- a/tests/models/oneformer/test_processing_oneformer.py +++ b/tests/models/oneformer/test_processing_oneformer.py @@ -406,7 +406,7 @@ def test_feat_extract_from_and_save_pretrained(self): with tempfile.TemporaryDirectory() as tmpdirname: feat_extract_first.save_pretrained(tmpdirname) - check_json_file_has_correct_format(os.path.join(tmpdirname, "preprocessor_config.json")) + check_json_file_has_correct_format(os.path.join(tmpdirname, "processor_config.json")) feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) self.assertEqual(feat_extract_second.image_processor.to_dict(), feat_extract_first.image_processor.to_dict()) diff --git a/tests/models/owlvit/test_processing_owlvit.py b/tests/models/owlvit/test_processing_owlvit.py index 46a7881b786a..5370c38d33f4 100644 --- a/tests/models/owlvit/test_processing_owlvit.py +++ b/tests/models/owlvit/test_processing_owlvit.py @@ -23,7 +23,7 @@ from transformers import CLIPTokenizer, CLIPTokenizerFast from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES from transformers.testing_utils import require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -60,9 +60,9 @@ def setUp(self): "image_mean": [0.48145466, 0.4578275, 0.40821073], "image_std": [0.26862954, 0.26130258, 0.27577711], } - self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(self.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) + image_processor = OwlViTImageProcessor(**image_processor_map) + processor = OwlViTProcessor(tokenizer=self.get_tokenizer(), image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) def 
get_tokenizer(self, **kwargs): return CLIPTokenizer.from_pretrained(self.tmpdirname, pad_token="!", **kwargs) diff --git a/tests/models/speech_to_text/test_processing_speech_to_text.py b/tests/models/speech_to_text/test_processing_speech_to_text.py index 9ed008b834da..0116445c91b1 100644 --- a/tests/models/speech_to_text/test_processing_speech_to_text.py +++ b/tests/models/speech_to_text/test_processing_speech_to_text.py @@ -21,7 +21,6 @@ from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio -from transformers.utils import FEATURE_EXTRACTOR_NAME from .test_feature_extraction_speech_to_text import floats_list @@ -55,7 +54,10 @@ def setUpClass(cls): "return_attention_mask": False, "do_normalize": True, } - save_json(feature_extractor_map, save_dir / FEATURE_EXTRACTOR_NAME) + feature_extractor = Speech2TextFeatureExtractor(**feature_extractor_map) + tokenizer = Speech2TextTokenizer.from_pretrained(cls.tmpdirname) + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + processor.save_pretrained(cls.tmpdirname) def get_tokenizer(self, **kwargs): return Speech2TextTokenizer.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/speecht5/test_processing_speecht5.py b/tests/models/speecht5/test_processing_speecht5.py index a850c5fd9383..a6736132a390 100644 --- a/tests/models/speecht5/test_processing_speecht5.py +++ b/tests/models/speecht5/test_processing_speecht5.py @@ -13,8 +13,6 @@ # limitations under the License. 
"""Tests for the SpeechT5 processors.""" -import json -import os import shutil import tempfile import unittest @@ -22,7 +20,6 @@ from transformers import is_speech_available, is_torch_available from transformers.models.speecht5 import SpeechT5Tokenizer from transformers.testing_utils import get_tests_dir, require_speech, require_torch -from transformers.utils import FEATURE_EXTRACTOR_NAME if is_speech_available() and is_torch_available(): @@ -60,9 +57,10 @@ def setUpClass(cls): "return_attention_mask": True, } - cls.feature_extraction_file = os.path.join(cls.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(cls.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(feature_extractor_map) + "\n") + feature_extractor = SpeechT5FeatureExtractor(**feature_extractor_map) + tokenizer = SpeechT5Tokenizer.from_pretrained(cls.tmpdirname) + processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + processor.save_pretrained(cls.tmpdirname) def get_tokenizer(self, **kwargs): return SpeechT5Tokenizer.from_pretrained(self.tmpdirname, **kwargs) diff --git a/tests/models/vision_text_dual_encoder/test_processing_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_processing_vision_text_dual_encoder.py index 2e0c21e63342..ef9699ff4f28 100644 --- a/tests/models/vision_text_dual_encoder/test_processing_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_processing_vision_text_dual_encoder.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os import shutil import tempfile @@ -21,7 +20,7 @@ from transformers import BertTokenizerFast from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer from transformers.testing_utils import require_tokenizers, require_vision -from transformers.utils import IMAGE_PROCESSOR_NAME, is_torchvision_available, is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -51,12 +50,8 @@ def setUpClass(cls): "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5], } - cls.image_processor_file = os.path.join(cls.tmpdirname, IMAGE_PROCESSOR_NAME) - with open(cls.image_processor_file, "w", encoding="utf-8") as fp: - json.dump(image_processor_map, fp) - + image_processor = ViTImageProcessor(**image_processor_map) tokenizer = cls.get_tokenizer() - image_processor = cls.get_image_processor() processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor) processor.save_pretrained(cls.tmpdirname) diff --git a/tests/models/wav2vec2/test_processing_wav2vec2.py b/tests/models/wav2vec2/test_processing_wav2vec2.py index dc9ae4136315..0ecf6d00e012 100644 --- a/tests/models/wav2vec2/test_processing_wav2vec2.py +++ b/tests/models/wav2vec2/test_processing_wav2vec2.py @@ -20,7 +20,6 @@ from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES -from transformers.utils import FEATURE_EXTRACTOR_NAME from ...test_processing_common import ProcessorTesterMixin from .test_feature_extraction_wav2vec2 import floats_list @@ -52,15 +51,13 @@ def setUpClass(cls): cls.tmpdirname = tempfile.mkdtemp() cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) - cls.feature_extraction_file = os.path.join(cls.tmpdirname, FEATURE_EXTRACTOR_NAME) with open(cls.vocab_file, "w", 
encoding="utf-8") as fp: fp.write(json.dumps(vocab_tokens) + "\n") - - with open(cls.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(feature_extractor_map) + "\n") - tokenizer = cls.get_tokenizer() - tokenizer.save_pretrained(cls.tmpdirname) + + feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_map) + processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + processor.save_pretrained(cls.tmpdirname) @classmethod def get_tokenizer(cls, **kwargs_init): diff --git a/tests/models/wav2vec2_bert/test_processing_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_processing_wav2vec2_bert.py index d05d3e2d44e0..fbdb73fff4ca 100644 --- a/tests/models/wav2vec2_bert/test_processing_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_processing_wav2vec2_bert.py @@ -22,7 +22,6 @@ from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor -from transformers.utils import FEATURE_EXTRACTOR_NAME from ...test_processing_common import ProcessorTesterMixin from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list @@ -53,15 +52,13 @@ def setUpClass(cls): cls.tmpdirname = tempfile.mkdtemp() cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) - cls.feature_extraction_file = os.path.join(cls.tmpdirname, FEATURE_EXTRACTOR_NAME) with open(cls.vocab_file, "w", encoding="utf-8") as fp: fp.write(json.dumps(vocab_tokens) + "\n") - - with open(cls.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(feature_extractor_map) + "\n") - tokenizer = cls.get_tokenizer() - tokenizer.save_pretrained(cls.tmpdirname) + + feature_extractor = SeamlessM4TFeatureExtractor(**feature_extractor_map) + processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + 
processor.save_pretrained(cls.tmpdirname) @classmethod def get_tokenizer(cls, **kwargs_init): diff --git a/tests/models/wav2vec2_with_lm/test_processing_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processing_wav2vec2_with_lm.py index 705fe30bba38..d23f84735f3d 100644 --- a/tests/models/wav2vec2_with_lm/test_processing_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processing_wav2vec2_with_lm.py @@ -29,7 +29,7 @@ from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow -from transformers.utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available +from transformers.utils import is_pyctcdecode_available, is_torch_available from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list @@ -66,16 +66,18 @@ def setUp(self): self.tmpdirname = tempfile.mkdtemp() self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) - self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) with open(self.vocab_file, "w", encoding="utf-8") as fp: fp.write(json.dumps(vocab_tokens) + "\n") - with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(feature_extractor_map) + "\n") - # load decoder from hub self.decoder_name = "hf-internal-testing/ngram-beam-search-decoder" + feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_map) + processor = Wav2Vec2ProcessorWithLM( + tokenizer=self.get_tokenizer(), feature_extractor=feature_extractor, decoder=self.get_decoder() + ) + processor.save_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs_init): kwargs = self.add_kwargs_tokens_map.copy() kwargs.update(kwargs_init) diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 295ee03a769e..924efe2db6bf 
100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -222,8 +222,7 @@ def test_processor_from_and_save_pretrained_as_nested_dict(self): processor_first = self.get_processor() with tempfile.TemporaryDirectory() as tmpdirname: - # Save with `legacy_serialization=False` so that all attrbiutes are saved in one json file - saved_files = processor_first.save_pretrained(tmpdirname, legacy_serialization=False) + saved_files = processor_first.save_pretrained(tmpdirname) check_json_file_has_correct_format(saved_files[0]) # Load it back and check if loaded correctly diff --git a/time_eval.py b/time_eval.py new file mode 100644 index 000000000000..c4c79df3c7bc --- /dev/null +++ b/time_eval.py @@ -0,0 +1,10 @@ +from transformers import AutoConfig, LlavaConfig + + +remote_text_config = AutoConfig.from_pretrained("AI4Chem/ChemLLM-7B-Chat", trust_remote_code=True) +local_vision_config = AutoConfig.from_pretrained("google/siglip2-so400m-patch14-384") +config = LlavaConfig(text_config=remote_text_config, vision_config=local_vision_config, image_token_id=92544) +config.save_pretrained("local_llava") + + +config = LlavaConfig.from_pretrained("local_llava", trust_remote_code=True)