
thunderfx produces results with incorrect requires_grad #1733

Open
jjsjann123 opened this issue Feb 1, 2025 · 10 comments

@jjsjann123 (Collaborator) commented Feb 1, 2025

🐛 Bug

Issue found by @kevinstephano.

import torch
from thunder.dynamo import thunderfx

from transformers import AutoConfig
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel, Qwen2RotaryEmbedding

def config():
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    config.batch_size = 1
    config.seq_len = 4096
    config._attn_implementation = "sdpa"
    return config

class MyModel(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids: torch.LongTensor):
        inputs_embeds = self.embed_tokens(input_ids)
        past_seen_tokens = 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

        position_ids = cache_position.unsqueeze(0)

        cos, sin = self.rotary_emb(inputs_embeds, position_ids)
        return (inputs_embeds, cos, sin,)

cfg = config()
def inputs(dtype, batch_size=cfg.batch_size, seq_len=cfg.seq_len):
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device='cuda', requires_grad=False)
    return {"input_ids": input_ids}


model = MyModel(cfg).cuda().bfloat16()
eager_out = model(**inputs(torch.bfloat16))

thunder_model = thunderfx(model)
thunder_out = thunder_model(**inputs(torch.bfloat16))

for idx in range(len(eager_out)):
    print("eager requires_grad?", idx, eager_out[idx].requires_grad)
for idx in range(len(thunder_out)):
    print("thunder requires_grad?", idx, thunder_out[idx].requires_grad)

prints:

eager requires_grad? 0 True
eager requires_grad? 1 False
eager requires_grad? 2 False
thunder requires_grad? 0 True
thunder requires_grad? 1 True
thunder requires_grad? 2 True
@mruberry (Collaborator) commented Feb 3, 2025

fyi @IvanYashchuk

@IvanYashchuk (Collaborator)

Thunder uses torch.autograd.Function to register the Thunder-generated backward with PyTorch's Autograd. It is PyTorch Autograd's behavior to set requires_grad=True and a grad_fn attribute on all outputs of torch.autograd.Function.apply:

import torch
from transformers import AutoConfig
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel, Qwen2RotaryEmbedding

def config():
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    config.batch_size = 1
    config.seq_len = 4096
    config._attn_implementation = "sdpa"
    return config

class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rotary_emb, inputs_embeds, position_ids):
        cos, sin = rotary_emb(inputs_embeds, position_ids)
        return (inputs_embeds, cos, sin,)

class MyModel(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids: torch.LongTensor):
        inputs_embeds = self.embed_tokens(input_ids)
        past_seen_tokens = 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

        position_ids = cache_position.unsqueeze(0)

        #cos, sin = self.rotary_emb(inputs_embeds, position_ids)
        return Test.apply(self.rotary_emb, inputs_embeds, position_ids)

cfg = config()
def inputs(dtype, batch_size=cfg.batch_size, seq_len=cfg.seq_len):
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device='cuda', requires_grad=False)
    return {"input_ids": input_ids}


model = MyModel(cfg).cuda().bfloat16()
eager_out = model(**inputs(torch.bfloat16))

for idx in range(len(eager_out)):
    print("eager requires_grad?", idx, eager_out[idx].requires_grad)

The above code snippet prints:

eager requires_grad? 0 True
eager requires_grad? 1 True
eager requires_grad? 2 True

And grad_fn points to the same object:

In [3]: [e.grad_fn for e in eager_out]
Out[3]: 
[<torch.autograd.function.TestBackward at 0x7efcc9c24bf0>,
 <torch.autograd.function.TestBackward at 0x7efcc9c24bf0>,
 <torch.autograd.function.TestBackward at 0x7efcc9c24bf0>]

@IvanYashchuk (Collaborator)

It's possible to mark certain outputs to have requires_grad=False with https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.mark_non_differentiable.html
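For illustration, here is a minimal sketch (hypothetical, not Thunder's actual backward registration) of applying mark_non_differentiable to the Test function from the previous comment:

import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rotary_emb, inputs_embeds, position_ids):
        cos, sin = rotary_emb(inputs_embeds, position_ids)
        # Tell Autograd that these two outputs never require gradients.
        ctx.mark_non_differentiable(cos, sin)
        return (inputs_embeds, cos, sin,)

    @staticmethod
    def backward(ctx, grad_inputs_embeds, grad_cos, grad_sin):
        # One return value per forward input (rotary_emb, inputs_embeds, position_ids);
        # gradients flow back only through inputs_embeds.
        return None, grad_inputs_embeds, None

With this change, cos and sin should come out with requires_grad=False, while inputs_embeds keeps requires_grad=True and its grad_fn.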

How important is it to have the same requires_grad as in eager? Is there an example of when this hurts performance?

@jjsjann123 (Collaborator, Author)

^^^ tagging @kevinstephano

@jjsjann123 (Collaborator, Author)

Ah, I think it's important to correctly propagate requires_grad in the Thunder trace then.

A small patch for that. #1725

@IvanYashchuk (Collaborator)

Why do you think it's important to propagate this attribute? Does it provide performance benefits?

Currently, requires_grad is consulted only on Tensor inputs; it is ignored on intermediates and outputs, and it is neither used nor tested there. Starting to maintain this feature could be a large undertaking.

@jjsjann123 (Collaborator, Author)

> It's possible to mark certain outputs to have requires_grad=False with https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.mark_non_differentiable.html
>
> How important is it to have the same requires_grad as in eager? Is there an example of when this hurts performance?

I thought that in order to mark outputs as requires_grad=False, we needed to propagate that field correctly?
Or is there a different mechanism we would use to know which outputs should have the requires_grad flag set?

@IvanYashchuk (Collaborator)

You're correct that to mark the outputs to have requires_grad=False we need to propagate this attribute. My question is whether we need to mark those outputs correctly. What do we get for doing this? Do we have to do it?

@mruberry (Collaborator)

> You're correct that to mark the outputs to have requires_grad=False we need to propagate this attribute. My question is whether we need to mark those outputs correctly. What do we get for doing this? Do we have to do it?

Being consistent with PyTorch is probably the right thing to do if we don't have a compelling reason to do otherwise. That said, it's probably not a high priority to be completely consistent unless we find customer scenarios impacted by the inconsistency.

@jjsjann123 (Collaborator, Author)

> You're correct that to mark the outputs to have requires_grad=False we need to propagate this attribute. My question is whether we need to mark those outputs correctly. What do we get for doing this? Do we have to do it?

My very naive understanding is that if we get requires_grad wrong on outputs, then in a hybrid system where tensors move in and out of the Thunder world, it could change what the downstream backward graph looks like.
I'm not sure whether this has a real-world impact.
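As a toy illustration of that concern (plain PyTorch, no Thunder; the tensors below are hypothetical stand-ins for the cos/sin outputs): a tensor that spuriously reports requires_grad=True adds an extra edge to the downstream autograd graph.

import torch

cos_wrong = torch.randn(4, 8, requires_grad=True)  # what thunderfx currently reports
cos_right = torch.randn(4, 8)                       # what eager reports
q = torch.randn(4, 8, requires_grad=True)

out_wrong = q * cos_wrong
out_right = q * cos_right

# The spurious flag adds an AccumulateGrad edge for cos_wrong that eager would not build.
print(out_wrong.grad_fn.next_functions)  # ((AccumulateGrad, 0), (AccumulateGrad, 0))
print(out_right.grad_fn.next_functions)  # ((AccumulateGrad, 0), (None, 0))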

Does thunderfx produce segmented Thunder sections? Would this issue manifest in that scenario? @IvanYashchuk, could you give any insight, since I don't know how that system works?

Also tagging @kevinstephano: have you seen this cause a real issue in your benchmarks? (I remember you already switched to thunderfx.)
