Skip to content

Commit

Permalink
Change back to Thread for SF conversion (#35236)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Dec 12, 2024
1 parent e3ee49f commit a691ccb
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from multiprocessing import Process
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from zipfile import is_zipfile

Expand Down Expand Up @@ -3825,11 +3825,11 @@ def from_pretrained(
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
Process(
Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
name="Process-auto_conversion",
name="Thread-auto_conversion",
).start()
else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/safetensors_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
# security breaches.
pr = previous_pr(api, model_id, pr_title, token=token)

if pr is None or (not private and pr.author != "SFConvertBot"):
if pr is None or (not private and pr.author != "SFconvertbot"):
spawn_conversion(token, private, model_id)
pr = previous_pr(api, model_id, pr_title, token=token)
else:
Expand Down
21 changes: 19 additions & 2 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import subprocess
import sys
import tempfile
import threading
import time
import unittest
from collections import defaultdict
Expand Down Expand Up @@ -2311,12 +2312,28 @@ class RequestCounter:

def __enter__(self):
self._counter = defaultdict(int)
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
self._thread_id = threading.get_ident()
self._extra_info = []

def patched_with_thread_info(func):
def wrap(*args, **kwargs):
self._extra_info.append(threading.get_ident())
return func(*args, **kwargs)

return wrap

self.patcher = patch.object(
urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
)
self.mock = self.patcher.start()
return self

def __exit__(self, *args, **kwargs) -> None:
for call in self.mock.call_args_list:
assert len(self.mock.call_args_list) == len(self._extra_info)

for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
if thread_id != self._thread_id:
continue
log = call.args[0] % call.args[1:]
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
if method in log:
Expand Down

0 comments on commit a691ccb

Please sign in to comment.