diff --git a/model-loader-huggingface/src/load.ipynb b/model-loader-huggingface/src/load.ipynb index 3646b9a..971b096 100644 --- a/model-loader-huggingface/src/load.ipynb +++ b/model-loader-huggingface/src/load.ipynb @@ -161,7 +161,7 @@ " return destination\n", "\n", "processes = []\n", - "with ThreadPoolExecutor(max_workers=10) as executor:\n", + "with ThreadPoolExecutor(max_workers=len(filenames)) as executor:\n", " for filename in filenames:\n", " processes.append(executor.submit(download_file, filename))\n", "\n", diff --git a/model-loader-huggingface/src/test_utils.py b/model-loader-huggingface/src/test_utils.py index 91f47b7..7128ec1 100644 --- a/model-loader-huggingface/src/test_utils.py +++ b/model-loader-huggingface/src/test_utils.py @@ -25,11 +25,28 @@ ], [ "config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", "model.safetensors.index.json", - "pytorch_model-00001-of-00002.bin", - "pytorch_model-00002-of-00002.bin", + ], + ), + ( + [ + "flax_model-00001-of-00005.msgpack", + "flax_model.msgpack.index.json", + "model-00001-of-00005.safetensors", + "model-00002-of-00005.safetensors", + "model.safetensors.index.json", + "pytorch_model-00001-of-00005.bin", "pytorch_model.bin.index.json", + "tf_model-00001-of-00005.h5", + "tf_model.h5.index.json" ], + [ + "model-00001-of-00005.safetensors", + "model-00002-of-00005.safetensors", + "model.safetensors.index.json", + ] ), ], ) diff --git a/model-loader-huggingface/src/utils.py b/model-loader-huggingface/src/utils.py index 065fdd7..169611e 100644 --- a/model-loader-huggingface/src/utils.py +++ b/model-loader-huggingface/src/utils.py @@ -1,13 +1,31 @@ from typing import List +def filter_pytorch_model(files: List[str]) -> List[str]: + return list(filter(lambda f: not f.startswith("pytorch_model") and not f == "pytorch_model.bin.index.json", files)) + +def filter_tensorflow_model(files: List[str]) -> List[str]: + return list(filter(lambda f: not f.startswith("tf_model") and not f == "tf_model.h5.index.json", files)) + +def filter_flax_model(files: List[str]) -> List[str]: + return list(filter(lambda f: not f.startswith("flax_model") and not f == "flax_model.msgpack.index.json", files)) + def filter_files(files: List[str]) -> List[str]: files = list(filter(lambda f: not f.startswith("coreml/"), files)) has_pytorch_model = any([f.startswith("pytorch_model") for f in files]) + has_tensorflow_model = any([f.startswith("tf_model") for f in files]) + has_safetensors = any([f.endswith(".safetensors") for f in files]) + if has_safetensors: + files = filter_pytorch_model(files) + files = filter_tensorflow_model(files) + files = filter_flax_model(files) + return files if has_pytorch_model: - # filter out safetensors - files = list(filter(lambda f: not f.endswith(".safetensors"), files)) - # filter out tensorflow model - files = list(filter(lambda f: not f.startswith("tf_model"), files)) + files = filter_tensorflow_model(files) + files = filter_flax_model(files) + return files + if has_tensorflow_model: + files = filter_flax_model(files) + return files return files