MPS acceleration + TRBV genes mice #35
Hi @celinebalaa, thanks for using SCEPTR! I'm glad you find it easy to use. Yes, I'd be happy to add support for MPS. And no, Mus musculus V genes (I assume this is what you mean by mouse?) are not currently supported. But I don't think it should be too difficult to add mouse support, and I am happy to do it. I will open two separate sub-issues for these two points now. In the meantime, if you would like to apply the model to mouse data, you can use just the CDR3 sequences: either use the CDR3-only model variant, or use the default model but feed it only the CDR3 sequences. Let me know if there's anything else I can help with!
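A minimal sketch of the second option. It assumes the column layout (`TRAV`, `CDR3A`, `TRBV`, `CDR3B`) of the test script in this thread. It also assumes, unverified, that SCEPTR accepts `None` in place of a V gene and treats it as missing information. `strip_v_genes` is a hypothetical helper, not part of SCEPTR:

```python
from typing import Dict, List, Optional

# Column names follow the example DataFrame used in the test script in this thread.
V_GENE_COLUMNS = ("TRAV", "TRBV")

TCRRecord = Dict[str, Optional[str]]


def strip_v_genes(tcrs: List[TCRRecord]) -> List[TCRRecord]:
    """Return copies of the TCR records with the V gene fields blanked out.

    Assumption (unverified): SCEPTR treats a None V gene as missing, so only
    the CDR3 sequences contribute to the representation.
    """
    stripped = []
    for tcr in tcrs:
        record = dict(tcr)  # copy, so the caller's records are left untouched
        for column in V_GENE_COLUMNS:
            record[column] = None
        stripped.append(record)
    return stripped
```

`pd.DataFrame(strip_v_genes(records))` would then give a DataFrame in the same layout as the one in the test script, with only the CDR3 columns filled in.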
@celinebalaa, I think I have managed to add support for MPS, and I have merged the new code into the main branch. Would you be willing to test whether it actually works? (I do not own any MPS devices 💀) To do this, could you please install the latest version from GitHub:

```bash
pip install git+https://github.com/yutanagano/sceptr.git
```

and then run the following script:

```python
import logging

import pandas as pd
import sceptr

# Write debug messages (including which torch device SCEPTR selects) to example.log
logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG)

tcrs = pd.DataFrame(
    data={
        "TRAV": ["TRAV38-1*01", "TRAV3*01", "TRAV13-2*01", "TRAV38-2/DV8*01"],
        "CDR3A": ["CAHRSAGGGTSYGKLTF", "CAVDNARLMF", "CAERIRKGQVLTGGGNKLTF", "CAYRSAGGGTSYGKLTF"],
        "TRBV": ["TRBV2*01", "TRBV25-1*01", "TRBV9*01", "TRBV2*01"],
        "CDR3B": ["CASSEFQGDNEQFF", "CASSDGSFNEQFF", "CASSVGDLLTGELFF", "CASSPGTGGNEQYF"],
    },
    index=[0, 1, 2, 3],
)

_ = sceptr.calc_vector_representations(tcrs)

# Hardware acceleration can also be switched off and back on
sceptr.disable_hardware_acceleration()
sceptr.enable_hardware_acceleration()
```
Hi @yutanagano, thank you for taking the time to add this feature, and for your quick replies! I really appreciate your efforts in addressing this. I installed the updated version from GitHub in a new virtual environment and tested it with a Python script. The log indicates that SCEPTR successfully detected the MPS device and set the torch device to `'mps'`: `DEBUG:sceptr._model_saves:Loading SCEPTR variant: SCEPTR`

However, I encountered an error when running `sceptr.calc_vector_representations()`: the operation `aten::_nested_tensor_from_mask_left_aligned` is not currently implemented for MPS. Here is the full error: `_ = sceptr.calc_vector_representations(tcrs)`

I tried setting `PYTORCH_ENABLE_MPS_FALLBACK=1` and tested the code again. It works, but it appears to be slower than running on the CPU. I am unsure whether this is due to the CPU fallback mechanism or to another factor, such as the batch size. In my experience, small batch sizes (e.g., fewer than 128) often make MPS slower than the CPU. However, I believe the default batch size here is 512, so batch size shouldn't be the issue.

Here is the code snippet for time measurement: `def measure_time_with_acceleration(acceleration_enabled):`

And the output: `CPU Time: 0.0095 seconds`

If you're working on a potential fix or update, I'd be happy to test it on my setup and provide feedback. Please let me know if I can assist further!
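A minimal sketch of a timing harness for this kind of comparison, using only the standard library. It sets `PYTORCH_ENABLE_MPS_FALLBACK` from Python, which is equivalent to exporting it in the shell. Setting it before `torch` is imported is the safest choice. The harness discards warm-up calls and reports the median of several timed runs. `benchmark` and its parameters are hypothetical names, not part of SCEPTR or PyTorch:

```python
import os
import statistics
import time
from typing import Callable

# Set before torch (and therefore sceptr) is imported, so that operations with
# no MPS implementation fall back to the CPU instead of raising an error.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def benchmark(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Return the median wall-clock time of fn() in seconds.

    Warm-up calls are run but not timed, so one-off costs (e.g. first use of
    the device) do not dominate a short measurement.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

With SCEPTR imported after the environment variable is set, `benchmark(lambda: sceptr.calc_vector_representations(tcrs))` could be run once after `sceptr.disable_hardware_acceleration()` and once after `sceptr.enable_hardware_acceleration()`. A four-TCR input such as the one in the test script only ever forms a batch of four, whatever the default batch size. A larger set of TCRs would therefore give a more representative comparison.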
Hey @celinebalaa, thanks for testing out the code! Sorry to hear that it didn't result in a speed improvement just yet. Thanks to your detailed report, the underlying issue seems quite clear: the nested tensor operation used inside PyTorch's transformer module (`aten::_nested_tensor_from_mask_left_aligned`) has no MPS implementation.
Thanks for this. Unfortunately, there isn't anything I can do to fix this at the moment. PyTorch has not yet implemented all the necessary operations for MPS hardware, and it's above my paygrade to implement them myself and submit them upstream to PyTorch. I checked pytorch/pytorch#77764 for the status of this particular operation. For now, I will hold off on pushing MPS support until PyTorch supports MPS reliably. In the meantime, I'll find some time to look into experimental support for Mus musculus V genes. Also, if you find that PyTorch's MPS support has improved and might be ready to test again, shoot me a message or open another issue and we can look at it together! :)
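For context, `aten::_nested_tensor_from_mask_left_aligned` appears to be the check that `torch.nn.TransformerEncoder` runs during inference when a padding mask is supplied. It tests whether the padded batch is "left aligned" before converting the batch to a nested tensor: each sequence's real tokens must come first, with padding last. The pure-Python function below only illustrates that condition and is not PyTorch's implementation. `is_left_aligned` is a hypothetical name, and `True` marks a real token:

```python
from typing import Sequence


def is_left_aligned(mask: Sequence[Sequence[bool]]) -> bool:
    """Return True if, in every row, all real tokens (True) precede all padding (False)."""
    for row in mask:
        seen_padding = False
        for is_token in row:
            if not is_token:
                seen_padding = True
            elif seen_padding:
                # A real token after padding breaks left alignment.
                return False
    return True


# Two sequences of lengths 2 and 1, padded on the right: left aligned.
print(is_left_aligned([[True, True, False], [True, False, False]]))  # True
# Padding in the middle of a sequence: not left aligned.
print(is_left_aligned([[True, False, True]]))  # False
```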
This is such a great tool: easy to install, intuitive, and well documented. It's already very fast, but to improve performance further for users on Apple Silicon Macs, would you consider adding support for Metal Performance Shaders (MPS) acceleration? Here's a simple check to use MPS when it is available:
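```python
import torch

# Use Apple's Metal Performance Shaders backend when available, otherwise the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
```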
Additionally, are the TRBV genes supported by the tool specific to humans, or are mouse genes also supported?
Thanks again for your hard work!