
linear/conv op dispatch #108

Closed
wants to merge 2 commits into from

Conversation

ASEM000
Owner

@ASEM000 ASEM000 commented Apr 11, 2024

  • Enable defining new rules for weight types other than jax.Array (e.g. LoRAWeight) at the parameter level of the linear/conv operations, rather than at the class level (e.g. LoRALinear).

Pros

  • The parameter-level approach composes better with `at` surgery.

    import serket as sk
    import functools as ft
    import re

    net = ...

    class CustomWeight(sk.TreeClass):
        ...

    @sk.nn.linear.def_type(CustomWeight)
    def _(....):
        # how to handle CustomWeight for linear;
        # all layers (and their compositions) built on sk.nn.Linear / sk.nn.linear
        # are then supported by default
        ...

    # select the weights of all linear layers and wrap them with CustomWeight
    net = net.at[re.compile("linear_*.")]["weight"].apply(ft.partial(CustomWeight, ...))
  • Better debugging, since everything happens at the Python level.
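To make the idea above concrete, here is a minimal sketch of parameter-level dispatch using `functools.singledispatch`. The names `linear` and `LoRAWeight` are illustrative stand-ins, not serket's actual API, and numpy is used in place of jax so the example is self-contained:

```python
# Sketch: dispatch a linear op on the *weight type*, in the spirit of this PR.
# `linear` and `LoRAWeight` are hypothetical; this is not serket's implementation.
from functools import singledispatch

import numpy as np


@singledispatch
def linear(weight, x):
    # default rule: a plain dense weight (any array-like)
    return x @ weight


class LoRAWeight:
    """Low-rank update W + A @ B, stored without materializing the sum."""

    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b


@linear.register(LoRAWeight)
def _(weight, x):
    # rule for LoRAWeight: apply the base and low-rank parts separately
    return x @ weight.w + (x @ weight.a) @ weight.b


x = np.ones((2, 4))
w = np.eye(4)
lora = LoRAWeight(w, np.zeros((4, 2)), np.zeros((2, 4)))
# with a zero low-rank update, the LoRA path matches the dense path
assert np.allclose(linear(w, x), linear(lora, x))
```

Because dispatch keys on the weight object rather than the layer class, any layer (or composition of layers) that routes through `linear` picks up the new rule automatically, which is what the parameter-level approach buys over a class-level `LoRALinear`.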

Cons:

Less general than tools that operate at the JAX level ({initial, final}-style), such as lorax, which can transform arbitrary JAX programs. However, building on non-public JAX APIs is risky, and staged-out initial-style code is harder to debug.

Let's give this a try.

@ASEM000 ASEM000 closed this Apr 11, 2024
@ASEM000 ASEM000 deleted the per-weight-dispatch branch April 16, 2024 12:34