
[Proposal] Ensure TransformerLens does not load from hugging face when config is passed in #754

Open
1 task done
hamind opened this issue Oct 11, 2024 · 2 comments
Labels
complexity-moderate Moderately complicated issues for people who have intermediate experience with the code

Comments

@hamind

hamind commented Oct 11, 2024

Proposal

Change some code so that models can be loaded locally.

Motivation

Today I wanted to load a GPT-2 model that I had downloaded from the Hugging Face website, loading it locally the way I can with Llama, but TransformerLens kept trying to connect to Hugging Face to download it.
Then I checked the code and found that:

  1. There is no way to load a model from a local copy.
  2. If the Hugging Face model has already been loaded, there is no need to download the model config from Hugging Face; it could be taken directly from the Hugging Face model.

Pitch

For models that were downloaded from Hugging Face, whether or not they are in the Hugging Face cache, provide a way to load the model locally.

Alternatives

  1. For HookedTransformer.from_pretrained, consider adding a parameter for passing a local model path.
  2. If the Hugging Face model has already been loaded, take the config directly from it.
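A minimal, stdlib-only sketch of the lookup order these alternatives suggest, assuming three sources in priority order: a config object the caller already holds (for example hf_model.config), a config.json inside a local model directory, and only then a remote download. `resolve_config` and `fetch_remote` are hypothetical names, not TransformerLens API; the remote loader is injected so it is easy to see that it is never called when a local source exists.

```python
import json
from pathlib import Path
from typing import Any, Callable, Optional


def resolve_config(
    model_name: str,
    hf_config: Optional[Any] = None,
    fetch_remote: Optional[Callable[[str], Any]] = None,
) -> Any:
    """Return a model config, touching the network only as a last resort.

    Hypothetical helper illustrating the proposal; not TransformerLens API.
    """
    # 1. A config the caller already holds (e.g. hf_model.config) wins outright.
    if hf_config is not None:
        return hf_config

    # 2. A local model directory containing config.json is read from disk.
    local_cfg = Path(model_name) / "config.json"
    if local_cfg.exists():
        with local_cfg.open(encoding="utf-8") as f:
            return json.load(f)

    # 3. Only now fall back to a remote lookup (e.g. AutoConfig.from_pretrained).
    if fetch_remote is None:
        raise FileNotFoundError(
            f"No config passed and no local config.json for {model_name!r}"
        )
    return fetch_remote(model_name)
```

Injecting the remote loader keeps the "no download when a config is available" guarantee testable without network access.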

Checklist

  • I have checked that there is no similar issue in the repo (required)
@hamind hamind closed this as completed Oct 11, 2024
@hamind hamind reopened this Oct 11, 2024
@bryce13950
Collaborator

Could you share the code you are using to load TransformerLens? You should be able to pass in your local version of the model with the hf_model parameter.
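A rough sketch of the hf_model route suggested here, assuming `transformers` and `transformer_lens` are installed and that `local_dir` is a directory written with `save_pretrained`. The wrapper name `load_local_hooked_transformer` is hypothetical; the imports are done inside it so the definition itself has no heavy dependencies. As this issue reports, `model_name` still has to be a name TransformerLens recognises, and its config may still be looked up on Hugging Face.

```python
def load_local_hooked_transformer(local_dir: str, model_name: str = "gpt2"):
    """Load a locally saved Hugging Face checkpoint into a HookedTransformer.

    Sketch only: hypothetical wrapper around the hf_model parameter.
    """
    # Lazy imports: only needed when the function is actually called.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformer_lens import HookedTransformer

    # local_files_only=True tells transformers not to contact the Hub
    # for the weights and tokenizer files.
    hf_model = AutoModelForCausalLM.from_pretrained(local_dir, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(local_dir, local_files_only=True)

    # hf_model supplies the weights; model_name selects the TransformerLens
    # config, which (per this issue) may still be fetched remotely.
    return HookedTransformer.from_pretrained(
        model_name, hf_model=hf_model, tokenizer=tokenizer
    )
```

Running with the environment variable HF_HUB_OFFLINE=1 makes huggingface_hub work from local files only, which is a quick way to check whether anything is still being downloaded.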

@bryce13950 bryce13950 added the needs-information More information is needed from the issue creator before moving forward. label Oct 15, 2024
@hamind
Author

hamind commented Oct 16, 2024

I've only changed a small amount of code, so I've pasted the relevant snippets directly here. For each change I've noted the Python file and line number, kept the original code as a comment, and placed the new code directly below it for easier review.

In transformer_lens/HookedTransformer.py, line 1257:
 
cfg = loading.get_pretrained_model_config(
    official_model_name,
    # hf_cfg=hf_cfg
    # Pass the loaded model's config object so no remote lookup is needed;
    # fall back to None when no hf_model was supplied.
    hf_cfg=hf_model.config if hf_model is not None else None,
    checkpoint_index=checkpoint_index,
    checkpoint_value=checkpoint_value,
    fold_ln=fold_ln,
    device=device,
    n_devices=n_devices,
    default_prepend_bos=default_prepend_bos,
    dtype=dtype,
    first_n_layers=first_n_layers,
    **from_pretrained_kwargs,
)
     
In transformer_lens/loading_from_pretrained.py, line 1583:

# if hf_cfg is not None:
#     cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)

if hf_cfg is not None:
    cfg_dict["load_in_4bit"] = hf_cfg.to_dict().get("quantization_config", {}).get("load_in_4bit", False)

     
In transformer_lens/loading_from_pretrained.py, line 708:

# def convert_hf_model_config(model_name: str, **kwargs):
def convert_hf_model_config(model_name: str, hf_config=None, **kwargs):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name as an input. If hf_config (a HuggingFace
    config object, e.g. hf_model.config) is passed, it is used instead of
    calling AutoConfig.from_pretrained.
    """
    if (Path(model_name) / "config.json").exists():
        logging.info("Loading model config from local directory")
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        # huggingface_token = os.environ.get("HF_TOKEN", None)
        # hf_config = AutoConfig.from_pretrained(
        #     official_model_name,
        #     token=huggingface_token,
        #     **kwargs,
        # )
        if hf_config is None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            hf_config = AutoConfig.from_pretrained(
                official_model_name,
                token=huggingface_token,
                **kwargs,
            )
        architecture = hf_config.architectures[0]
    ...
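The patched branch order above boils down to: name-based overrides for Llama and Gemma first, then the passed-in config's architectures[0], and only when no config was passed, a remote lookup. A hypothetical stdlib sketch of that ordering; `pick_architecture` and `load_remote_config` are illustrative names, with the remote loader injected so its calls can be observed.

```python
from typing import Any, Callable, Optional


def pick_architecture(
    official_model_name: str,
    hf_config: Optional[Any] = None,
    load_remote_config: Optional[Callable[[str], Any]] = None,
) -> str:
    """Mirror the branch order of the patched convert_hf_model_config (sketch)."""
    name = official_model_name.lower()
    # Name-based overrides run first and never need a config at all.
    if "llama" in name:
        return "LlamaForCausalLM"
    if "gemma-2" in name:  # checked before the plain "gemma" case
        return "Gemma2ForCausalLM"
    if "gemma" in name:
        return "GemmaForCausalLM"
    # Otherwise use the caller's config; fetch one only if none was given.
    if hf_config is None:
        if load_remote_config is None:
            raise ValueError("hf_config is required when no remote loader is available")
        hf_config = load_remote_config(official_model_name)
    return hf_config.architectures[0]
```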
     
In transformer_lens/loading_from_pretrained.py, lines 1525 and 1543:

    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
        cfg_dict = convert_hf_model_config(model_name, hf_config=hf_cfg, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
    else:
        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
            "trust_remote_code", False
        ):
            logging.warning(
                f"Loading model {official_model_name} requires setting trust_remote_code=True"
            )
            kwargs["trust_remote_code"] = True
        # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
        cfg_dict = convert_hf_model_config(official_model_name, hf_config=hf_cfg, **kwargs)
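The routing in get_pretrained_model_config can be pictured with the hypothetical sketch below: the converters are injected stubs, the prefix tuples are illustrative rather than copied from TransformerLens, and `route_config` is not a real function. It shows that str.startswith accepts a tuple of prefixes (which is how NEED_REMOTE_CODE_MODELS is used above) and that hf_cfg is forwarded to the Hugging Face branch so that branch can skip the remote config lookup.

```python
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

# Illustrative prefix tuples; the real lists live in TransformerLens.
NEEL_PREFIXES: Tuple[str, ...] = ("NeelNanda", "ArthurConmy", "Baidicoot")
REMOTE_CODE_PREFIXES: Tuple[str, ...] = ("example-org/remote-code-model",)


def route_config(
    model_name: str,
    hf_cfg: Optional[Any],
    convert_neel: Callable[[str], Dict[str, Any]],
    convert_hf: Callable[..., Dict[str, Any]],
    resolve_official_name: Callable[[str], str] = lambda name: name,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Pick a converter for model_name, forwarding hf_cfg to the HF path (sketch)."""
    # A local directory is used as-is instead of being mapped to an official name.
    if Path(model_name).exists():
        official_name = model_name
    else:
        official_name = resolve_official_name(model_name)
    if official_name.startswith(NEEL_PREFIXES):  # startswith accepts a tuple
        return convert_neel(official_name)
    if official_name.startswith(REMOTE_CODE_PREFIXES) and not kwargs.get(
        "trust_remote_code", False
    ):
        # The original code also logs a warning at this point.
        kwargs["trust_remote_code"] = True
    return convert_hf(official_name, hf_config=hf_cfg, **kwargs)
```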
     

@bryce13950 bryce13950 added complexity-moderate Moderately complicated issues for people who have intermediate experience with the code and removed needs-information More information is needed from the issue creator before moving forward. labels Nov 3, 2024
@bryce13950 bryce13950 changed the title [Proposal] Add function [Proposal] Ensure TransformerLens does not load from hugging face when config is passed in Nov 3, 2024
Projects
None yet
Development

No branches or pull requests

2 participants