
Multi Scale Deformable Attention Support #112827

@d4l3k

Description

🚀 The feature, motivation and pitch

Multi-scale deformable attention has gained traction in many recent bird's-eye-view and 3D perception papers. It offers a substantial performance improvement over full attention: each query attends to a small, learned set of sampling points across feature levels rather than to every key. A pure-PyTorch sketch of the sampling core follows.
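
For reference, the sampling core can be expressed in pure PyTorch via `F.grid_sample`, in the style of the Deformable DETR reference implementation. This is a minimal sketch, not a proposed API; the tensor names and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

def ms_deform_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    """value: (N, S, M, D) with S = sum(H_l * W_l) over L levels, M heads, head dim D.
    spatial_shapes: list of (H_l, W_l) per level.
    sampling_locations: (N, Lq, M, L, P, 2), normalized to [0, 1].
    attention_weights: (N, Lq, M, L, P), softmaxed over the L*P samples.
    """
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    value_per_level = value.split([H * W for H, W in spatial_shapes], dim=1)
    grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
    sampled = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        v = value_per_level[lvl].flatten(2).transpose(1, 2).reshape(N * M, D, H, W)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        g = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # bilinear sampling of P points per query: (N*M, D, Lq, P)
        sampled.append(F.grid_sample(v, g, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    w = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * w).sum(-1)  # (N*M, D, Lq)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()    # (N, Lq, M*D)
```

A fused CUDA kernel avoids materializing the per-level sampled tensors, which is where most of the speedup over this fallback comes from.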

Example Papers:

There are a handful of fragmented implementations available. It would be great to have this upstreamed to PyTorch core given the number of papers now using it.

These existing implementations have a number of issues, such as not supporting different data types and not registering their kernels through the torch ops registration mechanism.
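
For context, recent PyTorch versions (2.4+) expose `torch.library.custom_op` for exactly this kind of registration, which makes the op visible to `torch.compile`, meta tensors, and the dispatcher. Below is a minimal sketch of what registering the op could look like; the `myops` namespace and the signature are hypothetical:

```python
import torch

@torch.library.custom_op("myops::ms_deform_attn", mutates_args=())
def ms_deform_attn(value: torch.Tensor,
                   sampling_locations: torch.Tensor,
                   attention_weights: torch.Tensor) -> torch.Tensor:
    # A real registration would dispatch to CPU/CUDA kernels here; this
    # placeholder only returns a correctly shaped output for illustration.
    N, _, M, D = value.shape
    Lq = sampling_locations.shape[1]
    return value.new_zeros(N, Lq, M * D)

@ms_deform_attn.register_fake
def _(value, sampling_locations, attention_weights):
    # Shape-only "fake" implementation so the op traces under torch.compile.
    N, _, M, D = value.shape
    Lq = sampling_locations.shape[1]
    return value.new_empty(N, Lq, M * D)
```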

Example Implementations

The existing implementations are licensed under Apache 2.0. Is it possible to upstream them as-is, or would that require a complete rewrite or relicensing under BSD to match PyTorch core?

Alternatives

Installing a third-party library from source, or using mmcv, which pulls in a large number of dependencies.

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @bhosmer @cpuhrsch @erichan1 @drisspg

Labels

feature, module: nn, module: sdpa, triaged
