thunderfx produces results with incorrect requires_grad
#1733
Comments
fyi @IvanYashchuk
Thunder is using

```python
import torch
from transformers import AutoConfig
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel, Qwen2RotaryEmbedding


def config():
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    config.batch_size = 1
    config.seq_len = 4096
    config._attn_implementation = "sdpa"
    return config


class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rotary_emb, inputs_embeds, position_ids):
        cos, sin = rotary_emb(inputs_embeds, position_ids)
        return (inputs_embeds, cos, sin)


class MyModel(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids: torch.LongTensor):
        inputs_embeds = self.embed_tokens(input_ids)
        past_seen_tokens = 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
        position_ids = cache_position.unsqueeze(0)
        # cos, sin = self.rotary_emb(inputs_embeds, position_ids)
        return Test.apply(self.rotary_emb, inputs_embeds, position_ids)


cfg = config()


def inputs(dtype, batch_size=cfg.batch_size, seq_len=cfg.seq_len):
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device='cuda', requires_grad=False)
    return {"input_ids": input_ids}


model = MyModel(cfg).cuda().bfloat16()
eager_out = model(**inputs(torch.bfloat16))
for idx in range(len(eager_out)):
    print("eager requires_grad?", idx, eager_out[idx].requires_grad)
```

The above code snippet prints:

```
eager requires_grad? 0 True
eager requires_grad? 1 True
eager requires_grad? 2 True
```
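For comparison, a minimal sketch of checking the same flags on the ThunderFX side; the `thunder.dynamo.thunderfx` entry point and this exact usage are assumptions and may differ across Thunder versions:

```python
# Sketch only: reuses `model` and `inputs` from the reproduction above.
# The thunderfx import/entry point is assumed and may vary by Thunder version.
from thunder.dynamo import thunderfx

compiled_model = thunderfx(model)
thunder_out = compiled_model(**inputs(torch.bfloat16))
for idx in range(len(thunder_out)):
    print("thunderfx requires_grad?", idx, thunder_out[idx].requires_grad)
```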
It's possible to mark certain outputs to have the appropriate requires_grad. How important is it to have the same requires_grad as in eager? Is there an example of when this hurts performance?
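For context, in eager PyTorch the flag on custom `torch.autograd.Function` outputs is derived from the inputs: floating-point outputs require grad when grad mode is enabled and at least one differentiable input requires grad, unless the output is explicitly marked non-differentiable. This is consistent with all three outputs of `Test.apply` in the reproduction reporting `True`. A small standalone illustration (unrelated to Thunder internals):

```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        stats = x.mean()  # a float output we choose to treat as non-differentiable
        ctx.mark_non_differentiable(stats)
        return y, stats

    @staticmethod
    def backward(ctx, grad_y, _grad_stats):
        # Only y contributes gradients; stats was marked non-differentiable.
        return grad_y * 2

x = torch.randn(4, requires_grad=True)
y, stats = TwoOutputs.apply(x)
print(y.requires_grad)      # True: inherited from the differentiable input x
print(stats.requires_grad)  # False: explicitly marked non-differentiable
```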
^^^ tagging @kevinstephano
Ha, I think it's important to correctly propagate requires_grad. A small patch for that: #1725
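Not the actual change in #1725, just a sketch of the kind of consistency check this implies, with a hypothetical helper name:

```python
# Hypothetical test helper: compare requires_grad between eager and compiled outputs.
def assert_requires_grad_matches(eager_outs, compiled_outs):
    for i, (e, c) in enumerate(zip(eager_outs, compiled_outs)):
        assert e.requires_grad == c.requires_grad, (
            f"output {i}: eager requires_grad={e.requires_grad}, "
            f"compiled requires_grad={c.requires_grad}"
        )
```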
Why do you think it's important to propagate this attribute? Does it provide performance benefits? Currently
I thought in order to mark outputs to have
You're correct that to mark the outputs to have
Being consistent with PyTorch is probably the right thing to do if we don't have a compelling reason to do otherwise. That said, it's probably not a high priority to be completely consistent unless we find customer scenarios impacted by the inconsistency.
My very naive understanding here is that messing up requires_grad could cause real problems. Does thunderfx produce segmented Thunder sections? Would this issue manifest in that scenario? @IvanYashchuk, wondering if you could give any insight since I don't know how that system works. Also tagging @kevinstephano, wondering if you have seen this causing a real issue with your benchmarks (and I remember you already switched to thunderfx).
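To illustrate the failure mode in plain PyTorch terms (this is generic autograd behavior, not a claim about how ThunderFX segments graphs): if one segment hands the next a tensor whose requires_grad was incorrectly dropped, the autograd chain is cut and backward either errors or leaves upstream parameters without gradients.

```python
import torch

w = torch.randn(3, requires_grad=True)

# "Segment 1": detach() stands in for a boundary that incorrectly drops requires_grad.
intermediate = (w * 2).detach()

# "Segment 2": downstream computation on the handed-off tensor.
loss = (intermediate ** 2).sum()

try:
    loss.backward()
except RuntimeError as err:
    print(err)   # backward fails: the loss does not require grad

print(w.grad)    # None: no gradient ever reaches the parameter
```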
🐛 Bug
Issue found by @kevinstephano.
returns