SyntaxError when patching SFTTrainer in unsloth/tokenizer_utils.py #1698

Open
TobiAdeniji94 opened this issue Feb 13, 2025 · 9 comments
Labels
fixed - pending confirmation Fixed, waiting for confirmation from poster

Comments


TobiAdeniji94 commented Feb 13, 2025

Description:
I encountered the following error while running FastLanguageModel.from_pretrained with unsloth/Meta-Llama-3.1-8B:

First code block:
%%capture

!pip install "unsloth [colab-new] @git+https://github.com/unslothai/unsloth.git"

import torch
from packaging.version import Version as V
xformers = "xformers-0.0.27" if V(torch.version) < V("2.4.0") else "xformers"

!pip install --no-deps {xformers} trl peft accelerate bitsandbytes triton

Second code block:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Meta-Llama-3.1-8B",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)

Traceback:
File "unsloth/tokenizer_utils.py", line 1061, in
exec(trainer_text, globals())
File "", line 4
[invalid syntax here]

RuntimeError: Unsloth: Please file a bug report! Error patching SFTTrainer

Environment:

  • Python version: 3
  • PyTorch version:
  • Unsloth version:
  • Hardware: T4 GPU (Google Colab)

Steps to reproduce:

  1. Run the provided code snippet.
  2. The error occurs during the dynamic patching of SFTTrainer.
@edoproch

Same error - I opened an issue too.
It seems the example notebooks on Colab don't work either.
I tried Vast, Colab, and installing in various ways (also specifying my CUDA and torch versions), but nothing seems to work. Any update will be much appreciated!

@brunodoamaral

I added a print to see the value of trainer_text, and the result is below. The 4th line is not valid Python syntax.

class UnslothSFTTrainer(SFTTrainer):
    def __init__(
        self,
        model = <class 'inspect._empty'>,  # 4th line
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_loss_func = None,
        compute_metrics = None,
        callbacks = None,
        optimizers = (None, None),
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        tokenizer = None):
        
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            compute_loss_func = compute_loss_func,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizers = optimizers,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,
            processing_class = tokenizer if tokenizer else processing_class
        )
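
For what it's worth, that broken line is plausibly what you get when a parameter with no default is interpolated into generated source via its repr: inspect.Parameter.empty is the class inspect._empty, whose repr is not valid Python. A minimal sketch of the mechanism (my own reproduction, not Unsloth's actual code):

import inspect

def f(model):  # 'model' has no default, as the generated code above suggests for trl 0.15.0
    pass

default = inspect.signature(f).parameters["model"].default
print(repr(default))  # <class 'inspect._empty'>

# Interpolating that repr into generated source yields the broken 4th line:
line = f"model = {default},"
print(line)  # model = <class 'inspect._empty'>,  -> SyntaxError under exec()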

@brunodoamaral

This seems to be related to the new trl version (0.15.0), which has just been released.

I was able to fix it by downgrading: pip install "trl<0.15.0"
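
To confirm the pin took effect (a quick sanity check - restart the runtime first if you had already imported trl in this session):

import trl
print(trl.__version__)  # should report a pre-0.15.0 release, e.g. 0.14.x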

@mitzenjeremywoo

thanks! the solution works for me

@kkailaasa

When trying to run !pip install "trl<0.15.0", I'm getting this error:


AttributeError                            Traceback (most recent call last)

<ipython-input-5-74a5cc0ecc2b> in <cell line: 0>()
----> 1 from unsloth import FastLanguageModel
      2 import torch
      3 from datasets import load_dataset
      4 from trl import SFTTrainer
      5 from transformers import TrainingArguments

13 frames

/usr/local/lib/python3.11/dist-packages/torchvision/_meta_registrations.py in wrapper(fn)
     16 def register_meta(op_name, overload_name="default"):
     17     def wrapper(fn):
---> 18         if torchvision.extension._has_ops():
     19             get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
     20         return fn

AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)
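
Not from this thread, but in Colab that torchvision error usually means packages were reinstalled in a session that had already imported them. Restarting the runtime before re-importing typically clears it; one common way to force a restart from a cell (an assumption on my part, not an Unsloth-specific fix):

import os

# Kill the Colab kernel process; the runtime restarts and reconnects,
# so subsequent imports see the freshly installed package versions.
os.kill(os.getpid(), 9)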

@danielhanchen
Contributor

Working on a fix ASAP - sorry about the issue!

@brunodoamaral

@danielhanchen I was really surprised to discover that unsloth patches some classes via string replacement and exec(str, globals()). I'm not sure how the team came to this solution, but it is very unstable... as this bug caused by a dependency update shows.
Do you have a write-up anywhere explaining why this approach was chosen?

@danielhanchen added the "fixed - pending confirmation" (Fixed, waiting for confirmation from poster) label and removed the "currently fixing" (Am fixing now!) label on Feb 13, 2025
@danielhanchen
Contributor

Just fixed - for now, please use TRL < 0.15.0. I updated all the notebooks as well, so refresh them if you're using them. I.e. do:

pip uninstall trl -y && pip install --no-cache-dir --force-reinstall --no-deps "trl<0.15.0"

I'm still working on supporting the latest TRL 0.15.0, so it'll take a bit more time.

For GRPO runs, please use "unsloth==2025.2.4", i.e. pip install "unsloth==2025.2.4"

Using pip install unsloth should also work, since it'll pull in an older TRL version.

@danielhanchen
Contributor

@brunodoamaral The approach was taken since it's the best option for patching - rewriting entire swathes of code (or even copy-pasting and then overwriting) was one of the options, but we decided it was extremely time-consuming to handle.
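
To illustrate the general technique (a toy sketch of source-rewrite-and-exec patching, not Unsloth's actual code):

import inspect, textwrap

def greet(name):
    return f"Hello, {name}"

# Grab the function's source, rewrite it as a string, and re-execute it
# into a fresh namespace - the same idea, minus Unsloth's real rewrites.
src = textwrap.dedent(inspect.getsource(greet))
patched_src = src.replace("Hello", "Howdy")
namespace = {}
exec(patched_src, namespace)
greet = namespace["greet"]
print(greet("world"))  # Howdy, world

The upside is that the patched class tracks the upstream source automatically; the downside, as this issue shows, is that anything unexpected in the generated string (here, the repr of inspect._empty) breaks at exec() time.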

We normally collaborate directly with the Hugging Face team on launches, but this one sadly got out of whack schedule-wise.
