Make architecture compatibility check non-fatal if base model config cannot be loaded
tgaddair committed Mar 10, 2024
1 parent eebea85 commit d711299
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions server/lorax_server/utils/adapter.py
@@ -92,6 +92,28 @@ def _load_and_merge(
     return module_map, adapter_config, merged_weight_names, tokenizer
 
 
+def check_architectures(model_id: str, adapter_id: str, adapter_config: LoraConfig, api_token: str):
+    try:
+        expected_config = AutoConfig.from_pretrained(model_id, token=api_token)
+        model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path, token=api_token)
+    except Exception as e:
+        warnings.warn(
+            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
+            f"against model '{model_id}'. Assuming they are compatible. Error: {e}")
+        return
+
+    if model_config.architectures == expected_config.architectures:
+        warnings.warn(
+            f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
+            f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
+        )
+    else:
+        # TODO(travis): revisit this when we support classification heads which will not use CausalLM
+        raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
+                         f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
+                         f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
+
+
 @lru_cache(maxsize=128)
 def load_module_map(
     model_id: str,
@@ -106,18 +128,7 @@ def load_module_map(
     config_path = get_config_path(adapter_id, adapter_source)
     adapter_config = LoraConfig.from_pretrained(config_path, token=api_token)
     if adapter_config.base_model_name_or_path != model_id:
-        expected_config = AutoConfig.from_pretrained(model_id)
-        model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path)
-        if model_config.architectures == expected_config.architectures:
-            warnings.warn(
-                f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
-                f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
-            )
-        else:
-            # TODO(travis): revisit this when we support classification heads which will not use CausalLM
-            raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
-                             f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
-                             f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
+        check_architectures(model_id, adapter_id, adapter_config, api_token)
 
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(config_path, token=api_token)
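
The practical effect of this change: when the adapter's recorded base model config cannot be fetched (for example a private, renamed, or deleted repo, or no network access), adapter loading now proceeds with a warning instead of raising. The following is a minimal illustrative sketch, not part of the commit, assuming lorax_server and peft are installed; the adapter and repo names are hypothetical placeholders.

import warnings

from peft import LoraConfig

from lorax_server.utils.adapter import check_architectures

# Adapter config whose recorded base model cannot be resolved
# (hypothetical private or deleted repo).
adapter_config = LoraConfig(base_model_name_or_path="some-org/unreachable-base-model")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Before this commit, a failed AutoConfig load here raised and aborted
    # adapter loading; now check_architectures warns and assumes compatibility.
    check_architectures(
        model_id="mistralai/Mistral-7B-v0.1",
        adapter_id="some-org/my-lora-adapter",
        adapter_config=adapter_config,
        api_token=None,
    )

# Expect one "Unable to check architecture compatibility ..." warning.
print([str(w.message) for w in caught])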
