Adding Charged Fragment Types to MS2 Model Weights #228
base: development
Conversation
Really good and clean solution for the problem.
I would suggest refining the notebook and PR description a bit to make it clearer: 1) what the problem is, 2) what this PR implements, and 3) why this solves the problem. These could be sections or headings, etc.
I would suggest making it clearer that we will first show the old format and how the old weight format has led to issues. So, three examples (see the sketch below):
- Request `charged_frag_types` as in the weights, and loading goes fine.
- Request `charged_frag_types` with a different shape, and loading fails with a weight mismatch.
- Request `charged_frag_types` with the same shape (but different types or order), and the model pretends it works but is actually broken (worst case).
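A minimal sketch of that worst case, assuming the `pDeepModel` save/load interface shown later in this thread (the path and frag types are illustrative, and the silent mismatch describes the old weight format):

```python
from peptdeep.model.ms2 import pDeepModel

# Weights trained and saved for this exact order under the old format.
model = pDeepModel(charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'])
model.save('old_format.pth')

# Same number of frag types, different order: the old format loads this
# without complaint, silently mapping each output column to the wrong
# fragment type.
broken = pDeepModel(charged_frag_types=['y_z1', 'y_z2', 'b_z1', 'b_z2'])
broken.load('old_format.pth')  # no error, but predictions are mislabeled
```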
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## User importing weighst in the new format " |
Here I would be clearer: show the benefits of the new parameter format with different use cases.
"model = pDeepModel(charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'])\n", | ||
"model.load(new_path)\n", | ||
"print(f\"Model Interface has charged_frag_types {model.charged_frag_types}\")\n", | ||
"print(f\"Supported charged_frag_types in the loaded weights {model.model.supported_charged_frag_types}\")" |
The `model.model` interface feels a bit strange. Could we improve the naming?
I agree, but changing the attribute name in the model interface would be a breaking change, for example in the alphaDIA FinetuneManager.
I changed the variable name used in the notebook to be clearer.
model = pDeepModel(
    charged_frag_types=['c_z1', 'c_z2', 'y_z1', 'y_z2', 'x_z1', 'x_z2', 'y_modloss_z1', 'y_modloss_z2'],  # Will be overridden by the model weights
Can we leave the `charged_frag_types` empty to emphasize this? `charged_frag_types = []`
Not completely empty, since this initializes a PyTorch model, but I changed it to a single frag type and added print statements to emphasize this use case.
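A hedged sketch of this use case, assuming the `override_from_weights` constructor argument documented later in this PR (the path and placeholder frag type are illustrative):

```python
from peptdeep.model.ms2 import pDeepModel

# Initialize with a single placeholder frag type; the real set comes from
# the weights on load.
model = pDeepModel(
    charged_frag_types=['b_z1'],
    override_from_weights=True,
)
model.load(new_path)  # new_path: weights saved in the new format
print(model.charged_frag_types)  # now reflects the frag types in the weights
```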
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add an example showing that the order is automatically controlled by alphabase? Even if the user requests a random order, both the saved and predicted frag types are always in a defined order by `sort_charged_frag_types`.
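A small sketch of the suggested example, assuming `sort_charged_frag_types` is importable from alphabase's fragment module as the comment implies:

```python
from alphabase.peptide.fragment import sort_charged_frag_types

requested = ['y_z2', 'b_z1', 'y_z1', 'b_z2']  # deliberately shuffled order
print(sort_charged_frag_types(requested))
# Expected: the canonical order, e.g. ['b_z1', 'b_z2', 'y_z1', 'y_z2'],
# matching what is saved and predicted regardless of the requested order.
```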
        **kwargs,  # model params
    ):
        super().__init__(device=device)
        if mask_modloss is not None:
            warnings.warn("mask_modloss is deprecated and will be removed in the future. To mask the modloss fragments, the charged_frag_types should not include the modloss fragments.")
Does `mask_modloss` still work in this version?
Yes, and it is also one of the test cases.
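A hedged sketch of the deprecation path discussed here, based on the `warnings.warn` call visible in the diff above (the constructor arguments are illustrative):

```python
import warnings
from peptdeep.model.ms2 import pDeepModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = pDeepModel(
        charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'],
        mask_modloss=True,  # deprecated, but still accepted
    )
# The deprecation warning fires, yet the model is constructed as before.
assert any("deprecated" in str(w.message) for w in caught)
```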
peptdeep/model/ms2.py
device : str, optional
    Device to run the model, by default "gpu"
override_from_weights : bool, optional, default False
    Override the requested charged frag types when loading the model from weights; this will
Maybe make it clearer what this is doing:
Always update the requested charged fragment types from the model weights on loading. This allows predicting all fragment types supported by the weights, even if the user doesn't know which fragment types the weights support. Thereby, the model will always be in a safe-to-predict state.
@@ -504,6 +530,41 @@ def _set_batch_predict_data(
            predicts.reshape((-1, len(self.charged_frag_types))),
            batch_df[["frag_start_idx", "frag_stop_idx"]].values,
        )

    def _align_model_charged_frag_types(self):
        """
        Align the underlying model charged_frag_types with the interface charged_frag_types,
Are the interface charged frag types the requested ones?
I feel like the impact could be clearer; `align` sounds very vague.
Maybe something like `reshape_prediction_head`?
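For illustration, a hypothetical `reshape_prediction_head` could make the impact explicit; this is only a sketch assuming the prediction head is a final linear layer sized to `len(charged_frag_types)`:

```python
import torch

def reshape_prediction_head(head: torch.nn.Linear, n_frag_types: int) -> torch.nn.Linear:
    # Replace the output layer so its width matches the requested frag
    # types; the new head is randomly initialized and must be retrained.
    return torch.nn.Linear(head.in_features, n_frag_types)
```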
@@ -587,6 +655,10 @@ def predict(
        reference_frag_df=None,
        **kwargs,
    ) -> pd.DataFrame:
        if not self._safe_to_predict:
Please correct me, but shouldn't the most obvious solution be for users to select different (supported) charged frag types? I would mention this first, and only then suggest using different weights or retraining.
f"nn parameters {unexpect_keys} are UNEXPECTED while loading models in {self.__class__}" | ||
) | ||
|
||
def _update_model_state(self): |
Could this lead to a deviation between `_safe_to_predict` and reality?
Do you think it would make sense to do this as part of a `@property`?
Yes we can; the only use case preventing that is:
- The model is loaded and is safe to predict but not safe to train, and then the prediction-head reshape is called. Now both supported and requested frag types are aligned, which makes the model technically safe to predict, and having the state update as part of the getters would allow prediction with a randomly initialized prediction head even if training didn't start (e.g. it crashed).

Currently, the model only goes from 'not safe to predict' to 'safe to predict' after training completes.
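A toy sketch of the state flow described in this reply (class and method names are hypothetical, not the PR's API):

```python
class StateFlowSketch:
    def __init__(self):
        # Weights loaded and matching the requested frag types.
        self._safe_to_predict = True
        self._safe_to_train = False

    def reshape_prediction_head(self):
        # Frag types now align, but the head is randomly initialized: an
        # eager @property based on shape alignment would wrongly report
        # "safe to predict" at this point.
        self._safe_to_predict = False
        self._safe_to_train = True

    def train(self):
        # Only completed training restores the safe-to-predict state.
        self._safe_to_predict = True
```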
Ah I see 👍🏻
@@ -884,3 +995,31 @@ def calc_ms2_similarity(
    torch.cuda.empty_cache()
    return psm_df, metrics_describ


def charged_frags_to_tensor(charged_frags: List[str]) -> torch.Tensor:
Yes, this looks very nice.
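One plausible way such a round trip could work, sketched here only to illustrate the idea; the actual encoding in this PR may differ, and `tensor_to_charged_frags` is a hypothetical inverse:

```python
from typing import List
import torch

def charged_frags_to_tensor(charged_frags: List[str]) -> torch.Tensor:
    # Encode the frag type names as UTF-8 bytes of a ';'-joined string so
    # they can be stored alongside the weights in a state_dict.
    return torch.tensor(
        list(";".join(charged_frags).encode("utf-8")), dtype=torch.uint8
    )

def tensor_to_charged_frags(tensor: torch.Tensor) -> List[str]:
    # Inverse of the encoding above: bytes back to frag type names.
    return bytes(tensor.tolist()).decode("utf-8").split(";")
```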
This PR introduces a small but useful modification to the MS2 model interface and `ModelMS2Bert` by including charged fragment types as an attribute within the model weights. With this change, charged fragment types are saved to and loaded from disk with the weights, making it easier to adjust the requested fragment types for prediction and training.

Main Benefits
- Deprecation of the `mask_modloss` argument.

I've added a notebook (`adapt_charged_fragtypes.ipynb`) that walks through different use cases, including compatibility with older weight formats. It would be great if you could take a look and let me know if anything is missing.

This update modifies `ModelMS2Bert`, but if this approach looks good, I'd be happy to extend it to other architectures in a follow-up PR.

Supported use cases with the new format:
(1) The ideal use case (the only one supported by the old implementation), where users know and request exactly the same frag types supported in the model weights.
(2) Users only need to predict a subset of the frag types supported by the loaded weights.
(3) Users request charged frag types that are not supported (this can now be identified easily, since we no longer look only at the number of requested frag types).
(*) `Override from weights` is the new argument added to the MS2 model; it allows users to load models without knowing exactly which frag types a pretrained model supports. It overrides the requested frag types and uses all frag types supported by the loaded weights.

Safe to train: Any model is automatically set to 'safe to train' when the `train` function is called, by modifying only the underlying model output head to align with the requested frag types. After training, the model will automatically be in a 'safe to predict' state.
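A hedged end-to-end sketch of use cases (1) and (2), reusing the `pDeepModel` calls shown in the notebook excerpts above (paths are illustrative):

```python
from peptdeep.model.ms2 import pDeepModel

# Save weights in the new format: frag types travel with the weights.
model = pDeepModel(charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'])
model.save('ms2_new_format.pth')

# Use case (1): request exactly the supported frag types.
same = pDeepModel(charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'])
same.load('ms2_new_format.pth')

# Use case (2): request only a subset of the supported frag types.
subset = pDeepModel(charged_frag_types=['y_z1', 'y_z2'])
subset.load('ms2_new_format.pth')  # loads fine; the subset is supported
```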