
Unsloth overwrites the forward call function of a model loaded by huggingface library #1713

Open
DecoderLiu opened this issue Feb 14, 2025 · 2 comments

Comments

@DecoderLiu

I am trying out the GRPO notebook with a pretrained model as my reward model. Basically, I followed the notebook at https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb. I load the base model as the notebook does:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

Then I load my reward model using the Hugging Face library:

# Load reward tokenizer and model
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

reward_model_path = 'FreedomIntelligence/medical_o1_verifier_3B'
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_model_path, torch_dtype="auto", device_map="auto",
    attn_implementation="flash_attention_2", num_labels=2,
)

reward_template = """<Model Response>
{}
</Model Response>

<Reference Answer>
{}
</Reference Answer>

Your task is to evaluate the model response by comparing it to the reference answer. If the model response is correct and aligns with the reference answer, output "True" . If it is incorrect or fails to select the correct option (if options are provided), output "False" . {}"""


def medical_verifier(prompts, completions, answer, **kwargs) -> list[float]:
    responses = completions
    if answer is None:
        return [0.0]*len(responses)
    rewards = []
    for resp, ref in zip(responses, answer):
        text = reward_template.format(resp, ref, reward_tokenizer.eos_token)
        input_batch = reward_tokenizer([text], return_tensors="pt").to(reward_model.device)
        with torch.no_grad():
            logits = reward_model(**input_batch,return_dict=True).logits
            probabilities = F.softmax(logits, dim=-1)

        reward = 2.0 if probabilities[0,1] > 0.5 else 0.0
        rewards.append(reward)
    
    return rewards

Then I execute

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        medical_verifier
    ],
    args = training_args,
    train_dataset = train_split,
)
trainer.train()

Then I got the following error:

AttributeError                            Traceback (most recent call last)
Cell In[7], line 10
      1 trainer = GRPOTrainer(
      2     model = model,
      3     processing_class = tokenizer,
   (...)
      8     train_dataset = train_split,
      9 )
---> 10 trainer.train()

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/transformers/trainer.py:2171, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2169         hf_hub_utils.enable_progress_bars()
   2170 else:
-> 2171     return inner_training_loop(
   2172         args=args,
   2173         resume_from_checkpoint=resume_from_checkpoint,
   2174         trial=trial,
   2175         ignore_keys_for_eval=ignore_keys_for_eval,
   2176     )

File :382, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File :25, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File /gpfs/gibbs/project/lu_lu/ll2249/MedDiag/unsloth_compiled_cache/GRPOTrainer.py:410, in UnslothGRPOTrainer._prepare_inputs(self, inputs)
    407             for example in inputs:
    408                 # Repeat each value in the column for `num_generations` times
    409                 reward_kwargs[key].extend([example[key]] * self.num_generations)
--> 410         output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
    411         rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
    413 # Sum the rewards from all reward functions

Cell In[4], line 41, in medical_verifier(prompts, completions, answer, **kwargs)
     39 input_batch = reward_tokenizer([text], return_tensors="pt").to(reward_model.device)
     40 with torch.no_grad():
---> 41     logits = reward_model(**input_batch,return_dict=True).logits
     42     probabilities = F.softmax(logits, dim=-1)
     44 reward = 2.0 if probs[0,1] > 0.5 else 0.0

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:922, in LlamaForSequenceClassification.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    914 r"""
    915 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
    916     Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
    917     config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
    918     `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    919 """
    920 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 922 transformer_outputs = self.model(
    923     input_ids,
    924     attention_mask=attention_mask,
    925     position_ids=position_ids,
    926     past_key_values=past_key_values,
    927     inputs_embeds=inputs_embeds,
    928     use_cache=use_cache,
    929     output_attentions=output_attentions,
    930     output_hidden_states=output_hidden_states,
    931     return_dict=return_dict,
    932 )
    933 hidden_states = transformer_outputs[0]
    934 logits = self.score(hidden_states)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/unsloth/models/llama.py:868, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    865     hidden_states = layer_outputs[0]
    867 else:
--> 868     layer_outputs = decoder_layer(
    869         hidden_states,
    870         causal_mask=mask,
    871         attention_mask      = attention_mask,
    872         position_ids        = position_ids,
    873         past_key_value      = past_key_value,
    874         output_attentions   = output_attentions,
    875         use_cache           = use_cache,
    876         padding_mask        = padding_mask,
    877         position_embeddings = position_embeddings,
    878     )
    879     hidden_states = layer_outputs[0]
    880 pass

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/unsloth/models/llama.py:523, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, position_embeddings, *args, **kwargs)
    521 residual = hidden_states
    522 hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
--> 523 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    524     hidden_states       = hidden_states,
    525     causal_mask         = causal_mask,
    526     attention_mask      = attention_mask,
    527     position_ids        = position_ids,
    528     past_key_value      = past_key_value,
    529     output_attentions   = output_attentions,
    530     use_cache           = use_cache,
    531     padding_mask        = padding_mask,
    532     position_embeddings = position_embeddings,
    533 )
    534 hidden_states = residual + hidden_states
    536 # Fully Connected

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/unsloth/models/llama.py:386, in LlamaAttention_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, position_embeddings, *args, **kwargs)
    383 head_dim   = self.head_dim
    384 assert(n_kv_heads * n_groups == n_heads)
--> 386 Q, K, V = self.apply_qkv(self, hidden_states)
    387 Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    388 K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

File ~/.conda/envs/MedDiag/lib/python3.11/site-packages/torch/nn/modules/module.py:1931, in Module.__getattr__(self, name)
   1929     if name in modules:
   1930         return modules[name]
-> 1931 raise AttributeError(
   1932     f"'{type(self).__name__}' object has no attribute '{name}'"
   1933 )

AttributeError: 'LlamaAttention' object has no attribute 'apply_qkv'

So, Unsloth overwrites the forward call of a model loaded by the Hugging Face library, which then causes this error. I could switch my reward model to an Unsloth-supported model, but I am wondering if this can be solved.
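
(A quick way to confirm this, assuming the patching happens at the class level as the traceback suggests: after loading the base model with FastLanguageModel, inspect the class in transformers directly.)

import transformers.models.llama.modeling_llama as modeling_llama

# Once Unsloth has patched transformers, this prints its replacement
# rather than the stock method, e.g.
# <function LlamaAttention_fast_forward at 0x...>
print(modeling_llama.LlamaAttention.forward)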

@shimmyshimmer
Collaborator


Yes, that is correct. If you want to reuse the HF model call, you have to reload transformers with importlib.reload, and then it should work.
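
For reference, a minimal sketch of that workaround (untested; exactly which module needs reloading is an assumption based on the traceback above, which shows the patched code in transformers/models/llama/modeling_llama.py):

import importlib
import transformers.models.llama.modeling_llama as modeling_llama

# Re-execute the module Unsloth patched so its classes regain the stock
# transformers forward methods. Models that were already instantiated
# keep references to the old (patched) classes, so the Unsloth-loaded
# base model should be unaffected.
importlib.reload(modeling_llama)

# Load the reward model *after* the reload so it binds to the unpatched
# LlamaForSequenceClassification.
from transformers import AutoModelForSequenceClassification
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "FreedomIntelligence/medical_o1_verifier_3B",
    torch_dtype="auto", device_map="auto", num_labels=2,
)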

@DecoderLiu
Author

Yes, that is correct. If you want to reuse the HF model call, you have to reload transformers with importlib.reload, and then it should work.

I am using the AutoTokenizer from the official Hugging Face library. Could you tell me whether there are any other ways to solve this?
