MPS acceleration + TRBV genes mice #35

Open
celinebalaa opened this issue Jan 23, 2025 · 4 comments
Labels
enhancement New feature or request

Comments

@celinebalaa

This is such a great tool—easy to install, intuitive, and well-documented. It’s already super fast, but to enhance performance further for users on Apple Silicon Macs, would you consider adding support for Metal Performance Shaders (MPS) acceleration? Here’s a simple check to integrate MPS if available:
```python
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
```
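The check above can be extended into a fallback chain. This is a minimal sketch, assuming a preference order of CUDA, then MPS, then CPU; the helper name pick_device_name is hypothetical and not part of SCEPTR or PyTorch. Keeping the decision in a plain function makes it testable on machines without a GPU:

```python
def pick_device_name(mps_available: bool, cuda_available: bool) -> str:
    """Return the name of the torch device to use.

    Hypothetical helper (not part of SCEPTR): the preference order
    CUDA, then MPS, then CPU is an assumption made for illustration.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Example usage with PyTorch installed:
# import torch
# device = torch.device(
#     pick_device_name(torch.backends.mps.is_available(), torch.cuda.is_available())
# )
```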

Additionally, are the TRBV gene sets supported by the tool specific to humans, or are mouse genes also supported?

Thanks again for your hard work!

@yutanagano yutanagano self-assigned this Jan 23, 2025
@yutanagano
Owner

Hi @celinebalaa, thanks for using SCEPTR! I'm glad you find it easy to use. Yes, I'd be happy to add support for MPS, and no, Mus musculus (I assume this is what you mean by mouse?) V genes are not currently supported. But I don't think it should be too difficult to add mouse support, and I am happy to do it. I will make two separate sub-issues for these two points now. In the meantime, if you would like to apply the model to mouse data, you can just use the CDR3 sequences: you can use the CDR3-only model variant, or you could even just use the default model but only feed it the CDR3 sequences. Let me know if there's anything else I can help with!
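As a rough illustration of the second option, V gene calls can be blanked out before the data reaches the model. This is a minimal sketch with a hypothetical helper, cdr3_only, and it assumes SCEPTR treats None entries in the TRAV/TRBV columns as missing values; that assumption should be checked against SCEPTR's documentation before relying on it:

```python
# Hypothetical helper (not part of SCEPTR): keep only the CDR3 sequences by
# setting V gene fields to None. Assumes SCEPTR treats None as "missing".
V_GENE_COLUMNS = ("TRAV", "TRBV")


def cdr3_only(records):
    """Return copies of TCR records with their V gene fields set to None."""
    return [
        {key: (None if key in V_GENE_COLUMNS else value) for key, value in record.items()}
        for record in records
    ]


# Example usage (requires pandas and sceptr; mouse_records is a hypothetical
# list of dicts with TRAV, CDR3A, TRBV and CDR3B keys):
# import pandas as pd
# import sceptr
# tcrs = pd.DataFrame(cdr3_only(mouse_records))
# vectors = sceptr.calc_vector_representations(tcrs)
```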

@yutanagano yutanagano added the enhancement New feature or request label Jan 23, 2025
@yutanagano
Owner

@celinebalaa, I think I have managed to add in support for MPS. I have merged in the new code to the main branch. I was wondering if you would be willing to test if it does indeed work? (I do not own any MPS devices 💀)...

To do this, could you please:

  1. Create a fresh python environment using venv or conda or any other method of your choosing.
  2. Install sceptr from source in this new environment with:
     ```shell
     pip install git+https://github.com/yutanagano/sceptr.git
     ```
  3. Run the following little script:
     ```python
     import logging

     import pandas as pd
     import sceptr

     # Write all DEBUG-level log messages (including SCEPTR's) to example.log
     logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.DEBUG)

     # Four example human TCRs (alpha and beta V genes plus CDR3 sequences)
     tcrs = pd.DataFrame(
         data={
             "TRAV": ["TRAV38-1*01", "TRAV3*01", "TRAV13-2*01", "TRAV38-2/DV8*01"],
             "CDR3A": ["CAHRSAGGGTSYGKLTF", "CAVDNARLMF", "CAERIRKGQVLTGGGNKLTF", "CAYRSAGGGTSYGKLTF"],
             "TRBV": ["TRBV2*01", "TRBV25-1*01", "TRBV9*01", "TRBV2*01"],
             "CDR3B": ["CASSEFQGDNEQFF", "CASSDGSFNEQFF", "CASSVGDLLTGELFF", "CASSPGTGGNEQYF"],
         },
         index=[0, 1, 2, 3],
     )

     # Compute TCR embeddings with the default model
     _ = sceptr.calc_vector_representations(tcrs)

     # Toggle acceleration off and on again so the device-selection step is logged
     sceptr.disable_hardware_acceleration()
     sceptr.enable_hardware_acceleration()
     ```
  4. Running the previous script should have generated an example.log text file. Could you please paste its contents here? This should allow us to verify whether sceptr is smart enough to detect the MPS GPU cores on your machine to run its models on it.

@celinebalaa
Author

celinebalaa commented Jan 26, 2025

Hi @yutanagano, thank you for taking the time to add this feature and your quick replies! I really appreciate your efforts in addressing this.

I installed the updated version from GitHub in a new virtual environment and tested it with a Python script. The log indicates that SCEPTR successfully detected the MPS device and set the torch device to 'mps':

```
DEBUG:sceptr._model_saves:Loading SCEPTR variant: SCEPTR
DEBUG:sceptr.model:enable_hardware_acceleration called on <sceptr.model.Sceptr object at 0x127c340b0> (SCEPTR), setting device to mps
```
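When checking a longer log, the selected device can be pulled out programmatically. This is a minimal sketch, assuming the DEBUG message keeps the exact wording shown above (it may change between SCEPTR versions); devices_from_log is a hypothetical helper:

```python
import re

# Matches SCEPTR's DEBUG message as it appears in the log above, e.g.
# "...enable_hardware_acceleration called on <...> (SCEPTR), setting device to mps"
DEVICE_PATTERN = re.compile(r"enable_hardware_acceleration called on .*setting device to (\S+)")


def devices_from_log(text):
    """Return each device name that a log reports SCEPTR switching to."""
    return DEVICE_PATTERN.findall(text)


# Example usage:
# with open("example.log", encoding="utf-8") as f:
#     print(devices_from_log(f.read()))
```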

However, I encountered an error when running sceptr.calc_vector_representations(). The operation aten::_nested_tensor_from_mask_left_aligned is not currently implemented for MPS.

Here is the full error:

```
_ = sceptr.calc_vector_representations(tcrs)
Traceback (most recent call last):
  File "", line 1, in
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/sceptr/__init__.py", line 78, in calc_vector_representations
    return _get_default_model().calc_vector_representations(instances)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/sceptr/model.py", line 176, in calc_vector_representations
    torch_representations = self._calc_torch_representations(instances)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/sceptr/model.py", line 268, in _calc_torch_representations
    batch_representation = self._bert.get_vector_representations_of(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/libtcrlm/bert.py", line 36, in get_vector_representations_of
    self._vector_representation_delegate.get_vector_representations_of(
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/libtcrlm/vector_representation_delegate.py", line 72, in get_vector_representations_of
    final_token_embeddings = self._self_attention_stack.forward(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/libtcrlm/self_attention_stack.py", line 112, in forward
    return self._standard_stack.forward(projected_embeddings, padding_mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/libtcrlm/self_attention_stack.py", line 59, in forward
    return self._self_attention_stack.forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/celinebalaa/.local/lib/python3.12/site-packages/torch/nn/modules/transformer.py", line 454, in forward
    ) and not torch._nested_tensor_from_mask_left_aligned(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: The operator 'aten::_nested_tensor_from_mask_left_aligned' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

I tried setting PYTORCH_ENABLE_MPS_FALLBACK=1 and tested the code again. It is working; however, it appears to be slower than running on the CPU. I am unsure whether this is due to the CPU fallback mechanism or another factor, such as the batch size.
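For reference, the fallback variable can be set in the shell for a single run (PYTORCH_ENABLE_MPS_FALLBACK=1 python script.py) or from Python. A minimal sketch of the latter follows; it is safest to set the variable before torch, or anything that imports it such as sceptr, is imported:

```python
import os

# Ask PyTorch to run operators that are not yet implemented for MPS on the CPU
# instead. Set this before importing torch (or sceptr, which depends on it).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# import sceptr  # import only after the variable has been set
```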

From my experience, smaller batch sizes (e.g., less than 128) often result in slower performance on MPS compared to the CPU. However, I believe the default batch size in this case is 512, so batch size shouldn’t be the issue.
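A batch size caps how many rows are processed at once, so the actual batch can never be larger than the input itself. This is a minimal sketch of that arithmetic, assuming inputs are split into consecutive chunks; the 512 default is the value believed above rather than one verified here, and batch_sizes is a hypothetical helper:

```python
def batch_sizes(n_rows, batch_size=512):
    """Return the size of each consecutive batch of at most `batch_size` rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [min(batch_size, n_rows - start) for start in range(0, n_rows, batch_size)]
```

With the four-row tcrs DataFrame from the test script, batch_sizes(4) gives [4], a single batch of four rows, while batch_sizes(1000) gives [512, 488].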
(For additional context, both ESM2 and TCR-BERT models are working with MPS on my device without any significant performance issues.)

Here is the code snippet for time measurement:

```python
import time

import sceptr


def measure_time_with_acceleration(acceleration_enabled):
    # Switch SCEPTR between hardware acceleration (MPS here) and the CPU
    if acceleration_enabled:
        sceptr.enable_hardware_acceleration()
        print("Hardware acceleration enabled.")
    else:
        sceptr.disable_hardware_acceleration()
        print("Hardware acceleration disabled.")

    # Measure execution time using perf_counter
    # (tcrs is the DataFrame from the test script)
    start = time.perf_counter()
    _ = sceptr.calc_vector_representations(tcrs)
    elapsed = time.perf_counter() - start
    print(f"Time with {'MPS (fallback enabled)' if acceleration_enabled else 'CPU'}: {elapsed:.4f} seconds")
    return elapsed
```

And the output:

```
CPU Time: 0.0095 seconds
MPS Time: 0.1515 seconds
```
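Single-call timings of this size can be dominated by one-off costs such as initialising a device, so repeating the measurement after a warm-up gives a steadier comparison. This is a minimal sketch of such a helper; benchmark is hypothetical and not part of SCEPTR:

```python
import statistics
import time


def benchmark(fn, repeats=5, warmup=1):
    """Return the median wall-clock time of fn() in seconds.

    Warm-up calls run first and are discarded, so one-off costs
    (e.g. device initialisation) are excluded from the result.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

For example, benchmark(lambda: sceptr.calc_vector_representations(tcrs)) could be run once after sceptr.disable_hardware_acceleration() and once after sceptr.enable_hardware_acceleration().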

If you’re working on a potential fix or update, I’d be happy to test it on my setup and provide feedback. Please let me know if I can assist further!

@yutanagano
Owner

Hey @celinebalaa, thanks for testing out the code! Sorry to hear that it didn't result in a speed improvement just yet. Thanks to your detailed report, the underlying issue seems quite clear: PyTorch's transformer implementation calls a nested-tensor operation (aten::_nested_tensor_from_mask_left_aligned) that has not yet been implemented for MPS.

> If you’re working on a potential fix or update, I’d be happy to test it on my setup and provide feedback. Please let me know if I can assist further!

Thanks for this. Unfortunately, there isn't anything I can do to fix this at the moment. PyTorch has not yet implemented all the necessary operations for MPS hardware, and implementing one myself and submitting it upstream to PyTorch is above my paygrade. I checked pytorch/pytorch#77764 to see whether aten::_nested_tensor_from_mask_left_aligned is tracked as an operation that needs to be implemented, and it is, so we will just have to wait for full MPS acceleration support. Sorry I couldn't help further.

I will delay pushing support for MPS until PyTorch supports MPS reliably. In the meantime, I'll find some time to look into experimental support for Mus musculus V genes. Also, if you find that PyTorch's MPS support has improved and might be ready to test again, shoot me a message or submit another issue and we can look at it together! :)

2 participants