From 8373cb8a62f9d25c9587548f9132180fa79a6813 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 12 Dec 2024 14:42:10 +0100 Subject: [PATCH] fix --- src/transformers/modeling_utils.py | 6 +++--- src/transformers/testing_utils.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dae29111c8dcc0..5341a5832e8ab7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,7 +29,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps -from multiprocessing import Process +from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile @@ -3889,11 +3889,11 @@ def from_pretrained( **has_file_kwargs, } if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - Process( + Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, - name="Process-auto_conversion", + name="Thread-auto_conversion", ).start() else: # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 30f7b5a68fb2c0..c8ecf058ecf11e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -28,6 +28,7 @@ import subprocess import sys import tempfile +import threading import time import unittest from collections import defaultdict @@ -2311,12 +2312,26 @@ class RequestCounter: def __enter__(self): self._counter = defaultdict(int) - self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) + self._thread_id = threading.get_ident() + self._extra_info = [] + + def patched_with_thread_info(func): + def wrap(*args, **kwargs): + self._extra_info.append(threading.get_ident()) + return func(*args, **kwargs) + + return wrap + + self.patcher = patch.object(urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)) self.mock = self.patcher.start() return self def __exit__(self, *args, **kwargs) -> None: - for call in self.mock.call_args_list: + assert len(self.mock.call_args_list) == len(self._extra_info) + + for thread_id, call in zip(self._extra_info, self.mock.call_args_list): + if thread_id != self._thread_id: + continue log = call.args[0] % call.args[1:] for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): if method in log: