
Error running eval.py on checkpoint created using --pretrained_lm=gpt2. #71

eihli opened this issue Jan 14, 2024 · 0 comments

eihli (Contributor) commented Jan 14, 2024

Here are the steps to reproduce:

First, to show that the failure is specifically related to the --pretrained_lm argument, run the train/eval pair below once without --pretrained_lm=gpt2 in the training arguments, then run the pair again with --pretrained_lm=gpt2 added to the training arguments.

  1. Train
python -m pdb train.py \
    --training_steps=12 \
    --log_eval_freq=4 \
    --warmup_steps=1 \
    --batch_size=4 \
    --eval_episodes=1 \
    --activation_fn=gelu \
    --save_model \
    --save_mode=checkpoint \
    --text_prop=1.0 \
    --eval_text_log_examples \
    --text_datasets=wikitext-2-v1 \
    --text_datasets_paths=wikitext \
    --disable_cosine_decay
  2. Evaluate
python -m pdb eval.py \
    --model_path=./models/neko-gato-620082/checkpoint_12.pt \
    --text_datasets=wikitext-2-v1 \
    --text_datasets_paths=wikitext \
    --eval_episodes=1

When you run the above with the --pretrained_lm=gpt2 argument, you get the following error message:

RuntimeError: Error(s) in loading state_dict for GatoPolicy:
        Unexpected key(s) in state_dict: "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.bias", "transformer.h.8.attn.masked_bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.bias", "transformer.h.9.attn.masked_bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.bias", "transformer.h.10.attn.masked_bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.bias", "transformer.h.11.attn.masked_bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias".
        size mismatch for transformer.wte.weight: copying a param with shape torch.Size([50257, 768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
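
A minimal sketch of what appears to be happening (illustration only, not the project's code): the unexpected transformer.h.8–11 keys and the wte size mismatch both point at the checkpoint having been saved from a model built with the GPT-2 backbone, while eval.py rebuilds GatoPolicy with pretrained_lm effectively cleared, so the eval-time model is smaller and load_state_dict refuses the checkpoint:

import torch.nn as nn

# Stand-ins for the token embedding (transformer.wte) in the two models.
trained = nn.Embedding(50257, 768)   # shape saved in the gpt2-backed checkpoint
eval_time = nn.Embedding(1, 768)     # shape in the freshly built eval-time model

try:
    eval_time.load_state_dict(trained.state_dict())
except RuntimeError as e:
    # "size mismatch for weight: copying a param with shape torch.Size([50257, 768])
    #  from checkpoint, the shape in current model is torch.Size([1, 768])"
    print(e)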

Doing some archeology, we find that at one point pretrained_lm was removed from the training args before evaluation:

commit 1640ce0d97f9801695c1b2241ad6c29608e5f1e9
Author: Daniel Lawson <[email protected]>
Date:   Wed Jun 28 12:39:12 2023 -0400

    added init
---
 eval.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/eval.py b/eval.py
index f463c4e..8eaa8c8 100644
--- a/eval.py
+++ b/eval.py
@@ -24,6 +24,8 @@ def main(args):
         args_path = args.args_path

     training_args = json.load(open(args_path, 'r'))
+    if 'pretrained_lm' in training_args:
+        del training_args['pretrained_lm']

     # update args with eval_args
     for k, v in args.items():

That change was then modified so that the pretrained_lm argument is kept only if --lora was passed; otherwise it is set to None before the model is rebuilt.

commit ec4a486afc069f02c572c18cf73199223e0f1a8a
Author: Daniel Lawson <[email protected]>
Date:   Mon Jul 3 01:47:41 2023 -0400

    fixed eval afer lora
---
 eval.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/eval.py b/eval.py
index 9be6c83..0c827eb 100644
--- a/eval.py
+++ b/eval.py
@@ -6,6 +6,8 @@ import time
 import numpy as np
 import torch

+from peft import LoraConfig, TaskType, get_peft_model
+
 from gato.utils.utils import DotDict
 from gato.policy.gato_policy import GatoPolicy
 from gato.envs.setup_env import load_envs
@@ -24,8 +26,8 @@ def main(args):
         args_path = args.args_path

     training_args = json.load(open(args_path, 'r'))
-    if 'pretrained_lm' in training_args:
-        del training_args['pretrained_lm']
+    if not ('lora' in training_args and training_args['lora']):
+        training_args['pretrained_lm'] = None

     # update args with eval_args
     for k, v in args.items():
@@ -72,7 +74,15 @@ def main(args):
         use_patch_pos_encoding=not eval_args.disable_patch_pos_encoding,
         use_pos_encoding=not eval_args.disable_inner_pos_encoding,
         activation_fn=eval_args.activation_fn,
+        pretrained_lm=eval_args.pretrained_lm,
+        flash=eval_args.flash
     )
+
+    if eval_args.get('lora', False):
+        assert eval_args.pretrained_lm is not None, 'Must specify pretrained LM for LORA'
+        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=eval_args.lora_r, lora_alpha=eval_args.lora_alpha, lora_dropout=eval_args.lora_dropout)
+        model.transformer = get_peft_model(model.transformer, peft_config)
+
     model.load_state_dict(gato_checkpoint)
     model = model.to(eval_args.device)
     model.device = eval_args.device

The logic explicitly raises if lora is set and pretrained_lm is not (the assert above).

The logic implicitly fails if pretrained_lm is set and lora is not: training_args['pretrained_lm'] is set to None, so the GatoPolicy built by eval.py lacks the GPT-2 backbone and its state_dict no longer matches the checkpoint, which is the error shown above.

I'm not sure whether that's purposeful or accidental. I don't know much about LoRA and how it interacts with the model; I'm guessing it's fine to run pretrained_lm without LoRA.
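
For reference, here's a small, hypothetical sketch (made-up helper, not the real eval.py) of the two behaviors: the current one, which nulls pretrained_lm for non-LoRA runs, and the presumably intended one, which keeps the training-time value so the rebuilt GatoPolicy matches the checkpoint:

def resolve_pretrained_lm(training_args, drop_unless_lora):
    """Hypothetical helper mirroring eval.py's training_args handling."""
    if drop_unless_lora and not training_args.get('lora', False):
        # Current behavior (per the diff above): discard pretrained_lm, so the
        # eval-time model is built without the GPT-2 backbone.
        return None
    # Proposed behavior: keep whatever pretrained_lm the training run used.
    return training_args.get('pretrained_lm')

args = {'pretrained_lm': 'gpt2', 'lora': False}
print(resolve_pretrained_lm(args, drop_unless_lora=True))   # None -> state_dict mismatch
print(resolve_pretrained_lm(args, drop_unless_lora=False))  # gpt2 -> matches checkpoint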

Just logging this right now as research/investigation notes to pick back up later.

@eihli eihli changed the title Error running eval.py on checkpoing created using pretrained_lm. Error running eval.py on checkpoint created using --pretrained_lm=gpt2. Jan 14, 2024
@eihli eihli self-assigned this Jan 14, 2024
@eihli eihli added the bug Something isn't working label Jan 14, 2024