|
7 | 7 | from typing import Any, Dict, Optional, Type, Union
|
8 | 8 |
|
9 | 9 | import huggingface_hub
|
10 |
| -from huggingface_hub import (file_exists, hf_hub_download, |
| 10 | +from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, |
11 | 11 | try_to_load_from_cache)
|
12 | 12 | from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
13 | 13 | LocalEntryNotFoundError,
|
@@ -395,18 +395,28 @@ def get_sentence_transformer_tokenizer_config(model: str,
|
395 | 395 | - dict: A dictionary containing the configuration parameters
|
396 | 396 | for the Sentence Transformer BERT model.
|
397 | 397 | """
|
398 |
| - for config_name in [ |
399 |
| - "sentence_bert_config.json", |
400 |
| - "sentence_roberta_config.json", |
401 |
| - "sentence_distilbert_config.json", |
402 |
| - "sentence_camembert_config.json", |
403 |
| - "sentence_albert_config.json", |
404 |
| - "sentence_xlm-roberta_config.json", |
405 |
| - "sentence_xlnet_config.json", |
406 |
| - ]: |
407 |
| - encoder_dict = get_hf_file_to_dict(config_name, model, revision) |
408 |
| - if encoder_dict: |
409 |
| - break |
| 398 | + sentence_transformer_config_files = [ |
| 399 | + "sentence_bert_config.json", |
| 400 | + "sentence_roberta_config.json", |
| 401 | + "sentence_distilbert_config.json", |
| 402 | + "sentence_camembert_config.json", |
| 403 | + "sentence_albert_config.json", |
| 404 | + "sentence_xlm-roberta_config.json", |
| 405 | + "sentence_xlnet_config.json", |
| 406 | + ] |
| 407 | + try: |
| 408 | + # If the model is hosted on the Hugging Face Hub, list its repo files |
| 409 | + repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) |
| 410 | + except Exception as e: |
| 411 | + logger.debug("Error getting repo files", e) |
| 412 | + repo_files = [] |
| 413 | + |
| 414 | + encoder_dict = None |
| 415 | + for config_name in sentence_transformer_config_files: |
| 416 | + if config_name in repo_files or Path(model).exists(): |
| 417 | + encoder_dict = get_hf_file_to_dict(config_name, model, revision) |
| 418 | + if encoder_dict: |
| 419 | + break |
410 | 420 |
|
411 | 421 | if not encoder_dict:
|
412 | 422 | return None
|
|
0 commit comments