diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 830564a57..f2d38cc31 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1187,7 +1187,11 @@ def save( # For other cases, we want to add the class name: elif not class_ref.startswith("sentence_transformers."): class_ref = f"{class_ref}.{type(module).__name__}" - modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref}) + + module_config = {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref} + if self.module_kwargs and name in self.module_kwargs and (module_kwargs := self.module_kwargs[name]): + module_config["kwargs"] = module_kwargs + modules_config.append(module_config) with open(os.path.join(path, "modules.json"), "w") as fOut: json.dump(modules_config, fOut, indent=2) @@ -1550,7 +1554,7 @@ def _load_module_class_from_ref( if class_ref.startswith("sentence_transformers."): return import_from_string(class_ref) - if trust_remote_code: + if trust_remote_code or os.path.exists(model_name_or_path): code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None try: return get_class_from_dynamic_module(