
Support for bfloat16 Data Types with PyTorch Automatic Mixed Precision #366

Status: Open
agvico opened this issue Oct 24, 2024 · 0 comments
Labels: 0-needs-review, 1-feature (New feature or request)

agvico commented Oct 24, 2024

User story

As a user, I want lava-dl to support the bfloat16 data type, with compatibility with PyTorch's torch.amp (Automatic Mixed Precision), so that I can accelerate training and inference while maintaining numerical accuracy. This would reduce memory use and improve throughput by leveraging PyTorch's mixed precision capabilities for large-scale spiking neural networks (SNNs).

Conditions of satisfaction

  • The software should support bfloat16 data types for all relevant operations, including both training and inference.
  • Integration with torch.amp should be seamless, allowing users to easily switch between float32 and bfloat16, or to use automatic mixed precision, without significant code changes (see the sketch after this list).
  • The numerical stability and accuracy of operations with bfloat16 should be validated, ensuring compatibility with PyTorch's mixed precision training workflows.
  • Documentation should include guidelines on using bfloat16 with torch.amp, any limitations, and best practices for users.
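
A minimal sketch of the intended usage, under the assumption that lava-dl's SLAYER blocks accept bfloat16 activations (which is exactly the feature requested here). The `slayer.block.cuba.Dense` calls and neuron parameters follow the lava-dl tutorials; the layer sizes, dummy data, and loss function are illustrative placeholders, and the autocast context uses standard torch.amp APIs:

```python
# Illustrative sketch only: the network definition follows the public
# lava-dl SLAYER tutorials; whether the blocks currently accept bfloat16
# tensors under autocast is the support requested in this issue.
import torch
import torch.nn.functional as F
import lava.lib.dl.slayer as slayer


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        neuron_params = {
            'threshold': 1.25,
            'current_decay': 0.25,
            'voltage_decay': 0.03,
        }
        self.blocks = torch.nn.ModuleList([
            slayer.block.cuba.Dense(neuron_params, 200, 256),
            slayer.block.cuba.Dense(neuron_params, 256, 10),
        ])

    def forward(self, spike):
        for block in self.blocks:
            spike = block(spike)
        return spike


net = Network().to('cuda')
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# Dummy data: (batch, input neurons, time steps) spike tensor and class targets.
spike_input = (torch.rand(8, 200, 100, device='cuda') > 0.9).float()
target = torch.randint(0, 10, (8,), device='cuda')

# bfloat16 autocast; unlike float16, no torch.cuda.amp.GradScaler is needed
# because bfloat16 keeps float32's exponent range.
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    output = net(spike_input)      # (batch, classes, time)
    rate = output.mean(dim=-1)     # rate-code the output spike train
    loss = F.cross_entropy(rate, target)

loss.backward()
optimizer.step()
```

Because bfloat16 has the same exponent range as float32, the gradient-scaling step that float16 mixed precision requires can typically be omitted, which is part of what makes this integration attractive.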
agvico added the 1-feature (New feature or request) label on Oct 24, 2024