Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Dec 12, 2024
1 parent e5c45a6 commit 8373cb8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from multiprocessing import Process
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from zipfile import is_zipfile

Expand Down Expand Up @@ -3889,11 +3889,11 @@ def from_pretrained(
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
Process(
Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
name="Process-auto_conversion",
name="Thread-auto_conversion",
).start()
else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
Expand Down
19 changes: 17 additions & 2 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import subprocess
import sys
import tempfile
import threading
import time
import unittest
from collections import defaultdict
Expand Down Expand Up @@ -2311,12 +2312,26 @@ class RequestCounter:

def __enter__(self):
self._counter = defaultdict(int)
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
self._thread_id = threading.get_ident()
self._extra_info = []

def patched_with_thread_info(func):
def wrap(*args, **kwargs):
self._extra_info.append(threading.get_ident())
return func(*args, **kwargs)

return wrap

self.patcher = patch.object(urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug))
self.mock = self.patcher.start()
return self

def __exit__(self, *args, **kwargs) -> None:
for call in self.mock.call_args_list:
assert len(self.mock.call_args_list) == len(self._extra_info)

for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
if thread_id != self._thread_id:
continue
log = call.args[0] % call.args[1:]
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
if method in log:
Expand Down

0 comments on commit 8373cb8

Please sign in to comment.