diff --git a/model-trainer-huggingface/src/test_utils.py b/model-trainer-huggingface/src/test_utils.py
index 74c4f3f..4881c16 100644
--- a/model-trainer-huggingface/src/test_utils.py
+++ b/model-trainer-huggingface/src/test_utils.py
@@ -4,7 +4,7 @@ def test_parse_training_args_int_float():
-    params = {"num_train_epochs": "1"}
+    params = {"num_train_epochs": "1", "target_modules": "q,v"}
     assert parse_training_args(params).num_train_epochs == 1.0
 
     params = {"num_train_epochs": "1", "max_steps": "5"}
diff --git a/model-trainer-huggingface/src/train.ipynb b/model-trainer-huggingface/src/train.ipynb
index e157ad0..f6299c9 100644
--- a/model-trainer-huggingface/src/train.ipynb
+++ b/model-trainer-huggingface/src/train.ipynb
@@ -206,12 +206,13 @@
     "lora_config2 = LoraConfig(\n",
     "    r=16,\n",
     "    lora_alpha=32,\n",
-    "    # target modules should be unset so it can detect target_modules automatically\n",
-    "    # target_modules=[\"query_key_value\"],\n",
     "    lora_dropout=0.05,\n",
     "    bias=\"none\",\n",
     "    task_type=\"CAUSAL_LM\"\n",
     ")\n",
+    "target_modules = params.get(\"target_modules\")\n",
+    "if target_modules:\n",
+    "    lora_config2.target_modules = [mod.strip() for mod in target_modules.split(\",\")]\n",
     "\n",
     "model = prepare_model_for_kbit_training(model)\n",
     "\n",
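
For context, here is a minimal standalone sketch of what the updated notebook cell does: when a comma-separated `target_modules` value is supplied, it is split into a list and set on the `LoraConfig`; otherwise `target_modules` stays unset so PEFT can detect the modules automatically. The `params` dict and the `"q,v"` value are illustrative (taken from the updated test), not part of the patch itself.

```python
from peft import LoraConfig

# Example training params as string values, mirroring the test case above.
params = {"target_modules": "q,v"}

lora_config2 = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Only override target_modules when the param is present and non-empty;
# an absent/empty value leaves it unset for automatic detection.
target_modules = params.get("target_modules")
if target_modules:
    lora_config2.target_modules = [mod.strip() for mod in target_modules.split(",")]

print(lora_config2.target_modules)  # ['q', 'v']
```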