
Added custom mamba op and fixed the mamba cache issue #1521

Closed · wants to merge 3 commits

Conversation

zzhang37

What does this PR do?

Added custom mamba pscan op and fixed the mamba cache issue


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@zzhang37 zzhang37 requested a review from regisss as a code owner November 22, 2024 18:47
@libinta (Collaborator) left a comment:
@zzhang37 please update test to reflect the change

from transformers.utils import (
    ModelOutput,
    logging,
)

from pathlib import Path
import os

base_dir = "/workspace/custom_op_pscan_all"
Collaborator:

@libinta what would be the best way to set this without hardcoding? At least an env var?

Collaborator:

Is this dir generated on the fly? Or is it supposed to be downloaded (e.g. as part of an example)?

zzhang37 (Author) commented Dec 2, 2024:

I will change it based on our relative folder location.

zzhang37 (Author):

Done.

zzhang37 (Author):

Added HABANA_CUSTOM_OP_DIR to specify the custom op lib folder; otherwise the current folder is used as the lib folder.
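A minimal sketch of the env-var fallback described above, using only the standard library. The helper name `resolve_custom_op_dir` is illustrative, not the name used in the PR:

```python
import os
from pathlib import Path


def resolve_custom_op_dir(env_var: str = "HABANA_CUSTOM_OP_DIR") -> Path:
    """Return the custom op library folder.

    Prefer the env var if set; otherwise fall back to the
    current working directory instead of a hardcoded path.
    """
    configured = os.environ.get(env_var)
    return Path(configured) if configured else Path.cwd()


base_dir = resolve_custom_op_dir()
```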

A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
Collaborator:

@zzhang37 can you please add a comment in the code about the difference between this and the original impl?

zzhang37 (Author):

Done.


# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None):
    batch_size, seq_len, _ = input_states.shape
Collaborator:

@zzhang37, can you please add a brief code comment about the difference between this and the original?

Is it only Run_Mamba_Forward_Gaudi?

jiminha (Collaborator) commented Nov 26, 2024:

@zzhang37 Also, are all Synapse dependencies merged into the 1.19 release?

@libinta (Collaborator) left a comment:

Can you update the README front page and text-generation to add the command and a test case?

regisss (Collaborator) commented Dec 9, 2024:

@zzhang37 What's the difference between this PR and #1573 ?

zzhang37 (Author) commented Dec 9, 2024:

@regisss This PR is no longer needed. Please remove it. Thanks. We only need PR #1573.

@regisss regisss closed this Dec 9, 2024