Pass model and tokenizer arguments to Transformers
aphedges committed Nov 15, 2023
1 parent 12d8f38 commit 94247f2
Showing 2 changed files with 23 additions and 8 deletions.
24 changes: 17 additions & 7 deletions sentence_transformers/SentenceTransformer.py
@@ -44,12 +44,19 @@ def __init__(self, model_name_or_path: Optional[str] = None,
                  modules: Optional[Iterable[nn.Module]] = None,
                  device: Optional[str] = None,
                  cache_folder: Optional[str] = None,
-                 use_auth_token: Union[bool, str, None] = None
+                 use_auth_token: Union[bool, str, None] = None,
+                 model_args: Optional[Dict] = None,
+                 tokenizer_args: Optional[Dict] = None,
                  ):
         self._model_card_vars = {}
         self._model_card_text = None
         self._model_config = {}
 
+        if model_args is None:
+            model_args = {}
+        if tokenizer_args is None:
+            tokenizer_args = {}
+
         if cache_folder is None:
             cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
         if cache_folder is None:
@@ -92,9 +99,9 @@ def __init__(self, model_name_or_path: Optional[str] = None,
                                                use_auth_token=use_auth_token)
 
             if os.path.exists(os.path.join(model_path, 'modules.json')):    #Load as SentenceTransformer model
-                modules = self._load_sbert_model(model_path)
+                modules = self._load_sbert_model(model_path, model_args, tokenizer_args)
             else:   #Load with AutoModel
-                modules = self._load_auto_model(model_path)
+                modules = self._load_auto_model(model_path, model_args, tokenizer_args)
 
         if modules is not None and not isinstance(modules, OrderedDict):
             modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
@@ -800,16 +807,16 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
                 shutil.rmtree(old_checkpoints[0]['path'])
 
 
-    def _load_auto_model(self, model_name_or_path):
+    def _load_auto_model(self, model_name_or_path, model_args, tokenizer_args):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning("No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path))
-        transformer_model = Transformer(model_name_or_path)
+        transformer_model = Transformer(model_name_or_path, model_args=model_args, tokenizer_args=tokenizer_args)
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
         return [transformer_model, pooling_model]
 
-    def _load_sbert_model(self, model_path):
+    def _load_sbert_model(self, model_path, model_args, tokenizer_args):
         """
         Loads a full sentence-transformers model
         """
@@ -839,7 +846,10 @@ def _load_sbert_model(self, model_path):
         modules = OrderedDict()
         for module_config in modules_config:
             module_class = import_from_string(module_config['type'])
-            module = module_class.load(os.path.join(model_path, module_config['path']))
+            if module_class == Transformer:
+                module = module_class.load(os.path.join(model_path, module_config['path']), model_args, tokenizer_args)
+            else:
+                module = module_class.load(os.path.join(model_path, module_config['path']))
             modules[module_config['name']] = module
 
         return modules
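Taken together, the changes to SentenceTransformer.py thread the two new keyword arguments from __init__ through both loading paths. A minimal usage sketch of the new parameters (the model name and argument values are illustrative, not part of this commit; per the Transformer module, the dicts are forwarded to the underlying Hugging Face transformers loading calls):

    from sentence_transformers import SentenceTransformer

    # model_args is forwarded to the Transformers model/config loading,
    # tokenizer_args to the tokenizer loading (e.g. AutoTokenizer.from_pretrained).
    model = SentenceTransformer(
        'sentence-transformers/all-MiniLM-L6-v2',
        model_args={'revision': 'main'},
        tokenizer_args={'use_fast': True},
    )
    embeddings = model.encode(['An example sentence.'])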
7 changes: 6 additions & 1 deletion sentence_transformers/models/Transformer.py
@@ -138,7 +138,7 @@ def save(self, output_path: str):
             json.dump(self.get_config_dict(), fOut, indent=2)
 
     @staticmethod
-    def load(input_path: str):
+    def load(input_path: str, model_args: Optional[Dict] = None, tokenizer_args: Optional[Dict] = None):
         #Old classes used other config names than 'sentence_bert_config.json'
         for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']:
             sbert_config_path = os.path.join(input_path, config_name)
@@ -147,6 +147,11 @@ def load(input_path: str):
 
         with open(sbert_config_path) as fIn:
             config = json.load(fIn)
+        # Override stored model and tokenizer arguments if specified
+        if model_args:
+            config["model_args"] = {**config.get("model_args", {}), **model_args}
+        if tokenizer_args:
+            config["tokenizer_args"] = {**config.get("tokenizer_args", {}), **tokenizer_args}
         return Transformer(model_name_or_path=input_path, **config)


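The dict merges in load() give caller-supplied arguments precedence over the values stored in the saved sentence_bert_config.json, because later keys win in a {**a, **b} merge. A standalone sketch of that override behavior (the keys and values are hypothetical):

    stored_config = {'model_args': {'output_hidden_states': True}}
    caller_args = {'output_hidden_states': False, 'revision': 'main'}

    # Later entries win, so the caller's values override the stored ones.
    merged = {**stored_config.get('model_args', {}), **caller_args}
    print(merged)  # {'output_hidden_states': False, 'revision': 'main'}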
