diff --git a/tests/test_config.py b/tests/test_config.py index 9f7d85e39ad67..225d71c0bc0ea 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -64,8 +64,9 @@ def test_get_sliding_window(): def test_rope_customization(): - TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0} + TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} TEST_ROPE_THETA = 16_000_000.0 + LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} llama_model_config = ModelConfig( "meta-llama/Meta-Llama-3-8B-Instruct", @@ -95,29 +96,29 @@ def test_rope_customization(): None) == TEST_ROPE_THETA assert llama_model_config.max_model_len == 16384 - # TODO: add these back when the rope configs are fixed - # LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0} - # longchat_model_config = ModelConfig( - # "lmsys/longchat-13b-16k", - # "lmsys/longchat-13b-16k", - # tokenizer_mode="auto", - # trust_remote_code=False, - # dtype="float16", - # seed=0, - # ) - # assert getattr(longchat_model_config.hf_config, "rope_scaling", - # None) == LONGCHAT_ROPE_SCALING - # assert longchat_model_config.max_model_len == 16384 - - # longchat_model_config = ModelConfig( - # "lmsys/longchat-13b-16k", - # "lmsys/longchat-13b-16k", - # tokenizer_mode="auto", - # trust_remote_code=False, - # dtype="float16", - # seed=0, - # rope_scaling=TEST_ROPE_SCALING, - # ) - # assert getattr(longchat_model_config.hf_config, "rope_scaling", - # None) == TEST_ROPE_SCALING - # assert longchat_model_config.max_model_len == 4096 + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config + assert all( + longchat_model_config.hf_config.rope_scaling.get(key) == value + for key, value in LONGCHAT_ROPE_SCALING.items()) + assert longchat_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert longchat_model_config.max_model_len == 4096