zzmtsvv/mamba-interface

Clean Mamba code in JAX and PyTorch

This is my one-evening attempt to get more comfortable with JAX and Flax by porting a PyTorch implementation of Mamba [1]. It is closer to a fairly detailed interface of the model than to a complete project: training and inference code still need to be added. I hope this code helps you become more confident with JAX, Flax, or state-space models [2].

Feel free to contact me about any mistakes you find :) I have also tried to implement an associative scan in the jax folder, but it probably contains mistakes.
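
For readers who want to sanity-check an associative scan, here is a minimal sketch that is not taken from this repo. It assumes a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t with a zero initial state, and compares `jax.lax.associative_scan` against a plain sequential `jax.lax.scan`. The names `combine`, `scan_parallel`, and `scan_sequential`, as well as the shapes, are made up for illustration.

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    # Each element (a, b) represents the affine map h -> a * h + b.
    # Composing `left` first and `right` second gives another affine map,
    # and this composition is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


def scan_parallel(a, b):
    # a, b: arrays of shape (seq_len, d_state) (assumed layout).
    # Returns every h_t, assuming the state is zero before the first step.
    _, h = jax.lax.associative_scan(combine, (a, b))
    return h


def scan_sequential(a, b):
    # Reference implementation: the same recurrence, one step at a time.
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return hs


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (16, 4), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (16, 4))

h_par = scan_parallel(a, b)
h_seq = scan_sequential(a, b)
max_abs_diff = float(jnp.max(jnp.abs(h_par - h_seq)))
```

If the two results agree up to floating-point tolerance, the combine operator is probably right. In a Mamba-style layer, `a` would roughly correspond to the discretized state matrix and `b` to the discretized input term, but that mapping is an assumption here and not a description of the code in the jax folder.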

This repo is based on the following ones: annotated-mamba, mamba-minimal (in PyTorch), and the official implementation.

References

[1] Gu and Dao (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
[2] Gu et al. (2022). Efficiently Modeling Long Sequences with Structured State Spaces.
