Skip to content

Commit

Permalink
only download safetensors when available (#25)
Browse files Browse the repository at this point in the history
Fixes #24

* also adds more threads when there are more files
  • Loading branch information
samos123 committed Aug 5, 2023
1 parent 5e9d69c commit 497c160
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
2 changes: 1 addition & 1 deletion model-loader-huggingface/src/load.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
" return destination\n",
"\n",
"processes = []\n",
"with ThreadPoolExecutor(max_workers=10) as executor:\n",
"with ThreadPoolExecutor(max_workers=len(filenames)) as executor:\n",
" for filename in filenames:\n",
" processes.append(executor.submit(download_file, filename))\n",
"\n",
Expand Down
21 changes: 19 additions & 2 deletions model-loader-huggingface/src/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,28 @@
],
[
"config.json",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
"pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin",
],
),
(
[
"flax_model-00001-of-00005.msgpack",
"flax_model.msgpack.index.json",
"model-00001-of-00005.safetensors",
"model-00002-of-00005.safetensors",
"model.safetensors.index.json",
"pytorch_model-00001-of-00005.bin",
"pytorch_model.bin.index.json",
"tf_model-00001-of-00005.h5",
"tf_model.h5.index.json"
],
[
"model-00001-of-00005.safetensors",
"model-00002-of-00005.safetensors",
"model.safetensors.index.json",
]
),
],
)
Expand Down
26 changes: 22 additions & 4 deletions model-loader-huggingface/src/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from typing import List

def filter_pytorch_model(files: List[str]) -> List[str]:
    """Return a new list with the PyTorch weight files removed from *files*.

    Every name starting with ``pytorch_model`` is dropped, for example
    ``pytorch_model-00001-of-00005.bin`` and ``pytorch_model.bin.index.json``.
    The relative order of the remaining names is preserved.
    """
    # The prefix test already matches "pytorch_model.bin.index.json", so no
    # separate equality check on the index file is needed.
    return [f for f in files if not f.startswith("pytorch_model")]

def filter_tensorflow_model(files: List[str]) -> List[str]:
    """Return a new list with the TensorFlow weight files removed from *files*.

    Every name starting with ``tf_model`` is dropped, for example
    ``tf_model-00001-of-00005.h5`` and ``tf_model.h5.index.json``.
    The relative order of the remaining names is preserved.
    """
    # The prefix test already matches "tf_model.h5.index.json", so no
    # separate equality check on the index file is needed.
    return [f for f in files if not f.startswith("tf_model")]

def filter_flax_model(files: List[str]) -> List[str]:
    """Return a new list with the Flax weight files removed from *files*.

    Every name starting with ``flax_model`` is dropped, for example
    ``flax_model-00001-of-00005.msgpack`` and ``flax_model.msgpack.index.json``.
    The relative order of the remaining names is preserved.
    """
    # The prefix test already matches "flax_model.msgpack.index.json", so no
    # separate equality check on the index file is needed.
    return [f for f in files if not f.startswith("flax_model")]


def filter_files(files: List[str]) -> List[str]:
    """Reduce a repository file listing to a single weight format.

    Names under ``coreml/`` are always dropped. Of the weight formats present,
    only the most preferred one is kept, in this order of preference:

    1. safetensors (any name ending in ``.safetensors``)
    2. PyTorch (names starting with ``pytorch_model``)
    3. TensorFlow (names starting with ``tf_model``)

    Files of every less-preferred format (including Flax, which is never
    preferred over another format) are removed. Names that belong to no
    weight format, such as ``config.json``, are kept.

    Args:
        files: File names as listed in the model repository.

    Returns:
        A new list with the unwanted files removed, original order preserved.
    """
    files = [f for f in files if not f.startswith("coreml/")]

    has_safetensors = any(f.endswith(".safetensors") for f in files)
    has_pytorch_model = any(f.startswith("pytorch_model") for f in files)
    has_tensorflow_model = any(f.startswith("tf_model") for f in files)

    if has_safetensors:
        files = filter_pytorch_model(files)
        files = filter_tensorflow_model(files)
        return filter_flax_model(files)

    if has_pytorch_model:
        # No safetensors files exist on this path, so there is nothing to
        # strip for that format; only the less-preferred formats go.
        files = filter_tensorflow_model(files)
        return filter_flax_model(files)

    if has_tensorflow_model:
        return filter_flax_model(files)

    return files

0 comments on commit 497c160

Please sign in to comment.