# Provenance (recovered from the git patch this module was extracted from):
#   Commit:  a92654fba9a10b8e9d6da25f9d80e7b7116fc9e8
#   From:    Richard Kronick <41272573+richardkronick@users.noreply.github.com>
#   Date:    Thu, 6 Jun 2024 11:40:49 -0400
#   Subject: [PATCH] Unit tests loading from pretrained fill missing keys (#623)
#     * Add unit tests for fill_missing_keys
#     * Reformat test_loading_from_pretrained.py with black
#     * Rename unit test file to test_loading_from_pretrained_utilities to avoid
#       naming conflict
#   Path:    tests/unit/test_loading_from_pretrained_utilities.py (new file)
"""Unit tests for ``transformer_lens.loading_from_pretrained.fill_missing_keys``."""
from unittest import mock

import pytest

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import fill_missing_keys


def get_default_config():
    """Return a small, attention-only, single-layer config shared by the tests below."""
    return HookedTransformerConfig(
        d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True
    )


# Successes


@mock.patch("logging.warning")
def test_fill_missing_keys(mock_warning):
    """Keys containing ``W_`` that are absent from the input are restored, with a warning each."""
    cfg = get_default_config()
    model = HookedTransformer(cfg)
    default_state_dict = model.state_dict()

    incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "W_" not in k}

    # Snapshot the missing weight keys BEFORE calling fill_missing_keys. If the
    # function fills the dict it is given in place (NOTE(review): it appears to --
    # confirm against the implementation), recomputing this afterwards would find
    # nothing missing and the warning assertions below would silently never run.
    missing_weight_keys = [
        key for key in default_state_dict if "W_" in key and key not in incomplete_state_dict
    ]
    # Guard against a vacuous test: the model must expose at least one "W_" key.
    assert missing_weight_keys

    filled_state_dict = fill_missing_keys(model, incomplete_state_dict)

    assert set(filled_state_dict.keys()) == set(default_state_dict.keys())

    # Check that warnings were issued for missing weight matrices
    for key in missing_weight_keys:
        mock_warning.assert_any_call(
            f"Missing key for a weight matrix in pretrained, filled in with an empty tensor: {key}"
        )


def test_fill_missing_keys_with_hf_model_keys():
    """Keys containing ``hf_model`` are expected to stay absent from the filled result."""
    cfg = get_default_config()
    model = HookedTransformer(cfg)
    default_state_dict = model.state_dict()

    incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "hf_model" not in k}

    filled_state_dict = fill_missing_keys(model, incomplete_state_dict)

    # Every default key except those mentioning "hf_model".
    expected_keys = {k for k in default_state_dict if "hf_model" not in k}
    assert set(filled_state_dict.keys()) == expected_keys


def test_fill_missing_keys_no_missing_keys():
    """A state dict that is already complete comes back equal to the default one."""
    cfg = get_default_config()
    model = HookedTransformer(cfg)
    default_state_dict = model.state_dict()

    filled_state_dict = fill_missing_keys(model, default_state_dict)

    # NOTE(review): dict equality here relies on the values being the very same
    # tensor objects (Python short-circuits on identity); an element-wise tensor
    # ``==`` on distinct multi-element tensors would not yield a single bool.
    assert filled_state_dict == default_state_dict


# Failures


def test_fill_missing_keys_raises_error_on_invalid_model():
    """Passing ``None`` as the model is expected to raise ``AttributeError``."""
    invalid_model = None
    default_state_dict = {}

    with pytest.raises(AttributeError):
        fill_missing_keys(invalid_model, default_state_dict)