
Surface more informative error when adapter has NaN weights #81

Closed
arnavgarg1 opened this issue Nov 29, 2023 · 2 comments · Fixed by #168
Labels
enhancement New feature or request good first issue Good for newcomers

Comments

@arnavgarg1
Contributor

Feature request

When querying a base model with an adapter that has NaN or Inf weight tensors, LoRAX returns the following error:

The output tensors do not match for key base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight

It would be more helpful if the error message indicated that the reason the tensors don't match during the merge is that LoRAX detected NaN/Inf values in the adapter weights.

Motivation

This would help provide a rectifiable/actionable path for users who fine-tuned models and are working on testing them out to know that this isn't an issue with LoRAX, but rather, an issue with their trained adapter weights.

Your contribution

Happy to help surface a better error message! It seems the error is raised from this line in particular:

if not torch.equal(pt_tensor, sf_tensor):
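For context, a minimal sketch of how the check could distinguish corrupted weights from a genuine conversion mismatch (the `pt_tensor`/`sf_tensor` names come from the line above; the wrapper function and messages are illustrative, not LoRAX's actual code). The key detail is that `NaN != NaN`, so `torch.equal` returns `False` even when both tensors hold identical NaN bits:

```python
import torch

def check_tensors_match(key: str, pt_tensor: torch.Tensor, sf_tensor: torch.Tensor) -> None:
    """Raise a descriptive error if a converted tensor does not match the original."""
    if not torch.equal(pt_tensor, sf_tensor):
        # NaN compares unequal to itself, so any NaN/Inf in the weights makes
        # torch.equal fail even for bitwise-identical tensors. Surface that
        # case explicitly before reporting a generic mismatch.
        if not torch.isfinite(pt_tensor).all():
            raise ValueError(
                f"Tensor {key} contains NaN/Inf values; "
                "the adapter weights are likely corrupted."
            )
        raise RuntimeError(f"The output tensors do not match for key {key}")
```

With this, an adapter containing NaN weights would fail with the `ValueError` message rather than the opaque "output tensors do not match" error quoted above.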

@tgaddair tgaddair added enhancement New feature or request good first issue Good for newcomers labels Nov 29, 2023
@asingh9530
Contributor

asingh9530 commented Jan 9, 2024

@arnavgarg1 I just took a look; it seems that adding the following check

if torch.any(torch.isnan(<your tensor>)):
    raise ValueError("Output tensor contains NaN values; please fix the adapter and try again.")

here should fix this, but I would argue we should raise a custom exception based on the type of error we encounter, something like this:

try:
    ...  # your conversion / comparison code
except NanError:
    raise NanError("<your message>")
except NotEqualError:
    raise NotEqualError("<your message>")

# Simple custom exception classes
class CustomError(Exception):
    """Base class for all errors."""


class NanError(CustomError):
    """Exception raised when a tensor contains NaN values."""


class NotEqualError(CustomError):
    """Exception raised when an equality comparison between two tensors fails."""

@arnavgarg1 @tgaddair What do you think, guys?

@tgaddair
Contributor

tgaddair commented Jan 9, 2024

Great suggestion, @asingh9530! I put together a quick PR (#168) to test. Let me know if this addresses the issue!

The one thing I'm not sure about is how to handle NaNs if the adapter was written in safetensors format. It seems (based on the error above) that they're handled differently, so we may need to think about that case separately.
