
Adding Charged Fragment Types to MS2 Model Weights #228

Draft
wants to merge 8 commits into development

Conversation

@mo-sameh (Collaborator) commented Feb 12, 2025

This PR introduces a small but useful modification to the MS2 model interface and ModelMS2Bert, by including the charged fragment types as an attribute within the model weights. With this change, charged fragment types are saved to and loaded from disk together with the weights, making it easier to adjust the requested fragment types for prediction and training.

Main Benefits

  • More flexibility in training and inference
    • Supports partial weight loading to continue training with different fragment types (as suggested in PR allow partial loading for pre trained ms2 models #226).
    • Allows prediction with any subset of supported fragment types, including masking modloss fragments, removing the need for a separate mask_modloss argument.
  • Better handling of unsupported fragment types
    • Now raises clear and interpretable errors when an invalid fragment type is requested.
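As a sketch of what this enables (hypothetical names, not the actual peptdeep API): because the supported types now travel with the weights, an invalid request can fail loudly instead of silently mis-mapping output columns.

```python
# Hypothetical sketch, not the actual peptdeep API: with the supported
# charged frag types stored in the weights, requests can be validated
# explicitly instead of failing on tensor-shape mismatches.
SUPPORTED_IN_WEIGHTS = ["b_z1", "b_z2", "y_z1", "y_z2", "y_modloss_z1", "y_modloss_z2"]

def validate_requested(requested, supported=SUPPORTED_IN_WEIGHTS):
    """Raise a clear error if any requested frag type is unsupported."""
    unknown = [ft for ft in requested if ft not in supported]
    if unknown:
        raise ValueError(f"Unsupported charged_frag_types: {unknown}")
    return requested

# Masking modloss is just requesting a subset without the modloss types:
validate_requested(["b_z1", "b_z2", "y_z1", "y_z2"])
```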

I've added a notebook (adapt_charged_fragtypes.ipynb) that walks through different use cases, including compatibility with older weight formats. It would be great if you could take a look and let me know if anything is missing.

This update modifies ModelMS2Bert, but if this approach looks good, I’d be happy to extend it to other architectures in a follow-up PR.

Supported use cases with the new format:

Frag types use case          Override from weights (*)   Safe to predict   Safe to train
requested = supported (1)    False                       yes               yes
requested ⊆ supported (2)    False                       yes               no
requested ⊈ supported (3)    False                       no                no
Any                          True                        yes               yes

(1) The ideal use case (the only one supported by the old implementation), where users know and request exactly the frag types supported by the model weights.
(2) Users only need to predict a subset of the frag types supported by the loaded weights.
(3) Users requested charged frag types that are not supported (now easy to identify, since we no longer look only at the number of requested frag types).
(*) Override from weights is the new argument added to the MS2 model. It lets users load models without knowing exactly which frag types a pretrained model supports: the requested frag types are overridden, and all frag types supported by the loaded model are used.

Safe to train: any model is automatically set to 'safe to train' when the train function is called, by modifying only the underlying model's output head to align with the requested frag types. After training, the model is automatically in a 'safe to predict' state.
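The state transitions described above can be sketched as a toy state machine (a hypothetical class, simplified from the PR's description; the real implementation tracks this inside the model interface):

```python
# Toy state machine for the safe-to-predict / safe-to-train flags
# (hypothetical sketch, not the PR's actual code).
class Ms2ModelState:
    def __init__(self, supported, requested):
        self.supported = list(supported)
        self.requested = list(requested)
        # Predicting is safe only when every requested type is supported.
        self.safe_to_predict = set(requested) <= set(supported)

    def train(self):
        # train() reshapes the output head to the requested frag types, so
        # training is always safe; once training completes, the model can
        # predict exactly the requested types again.
        self.supported = list(self.requested)
        self.safe_to_predict = True
```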


@mo-sameh mo-sameh requested a review from GeorgWa February 12, 2025 18:51
@GeorgWa (Collaborator) left a comment:

Really good and clean solution for the problem.
I would suggest refining the notebook and PR description a bit to make clearer: 1) what the problem is, 2) what this PR implements, 3) why this solves the problem. This could be sections or headings etc.

]
},
{
"cell_type": "markdown",
Collaborator commented:

I would suggest making it clearer that we will first show the old format and how the old weight format has led to issues:

So three examples:

  1. Request charged_frag_types as in the weights, and loading goes fine.
  2. Request charged_frag_types with a different shape, and loading fails with a weight mismatch.
  3. Request charged_frag_types with the same shape (but different types or order), and the model pretends it works but is actually broken (worst case).
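Case 3 can be illustrated in a few lines (a toy example with no model involved): a same-shape reorder loads without error, yet every output column is silently mislabeled.

```python
# Toy illustration of the worst case: weights were saved for one column
# order, but the user requests the same number of types in another order.
# Loading "succeeds" (shapes match), yet every column is mislabeled.
saved_order = ["b_z1", "b_z2", "y_z1", "y_z2"]
requested = ["y_z1", "y_z2", "b_z1", "b_z2"]  # same shape, different order

prediction_row = [0.1, 0.2, 0.7, 0.9]  # intensities, really in saved_order
labeled = dict(zip(requested, prediction_row))

# The user reads a b_z1 intensity while believing it is y_z1:
assert labeled["y_z1"] == 0.1
assert dict(zip(saved_order, prediction_row))["b_z1"] == 0.1
```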

"cell_type": "markdown",
"metadata": {},
"source": [
"## User importing weights in the new format "
Collaborator commented:

Here I would be clearer: benefits of the new parameter format shown with different use cases.

"model = pDeepModel(charged_frag_types=['b_z1', 'b_z2', 'y_z1', 'y_z2'])\n",
"model.load(new_path)\n",
"print(f\"Model Interface has charged_frag_types {model.charged_frag_types}\")\n",
"print(f\"Supported charged_frag_types in the loaded weights {model.model.supported_charged_frag_types}\")"
Collaborator commented:

The model.model interface feels a bit strange. Could we improve the naming?

Collaborator (Author) commented:

I agree, but changing the attribute name in the model interface would be a breaking change, for example in alphaDIA's FinetuneManager.
I changed the variable name used in the notebook to be clearer.

],
"source": [
"model = pDeepModel(\n",
" charged_frag_types=['c_z1', 'c_z2', 'y_z1', 'y_z2', 'x_z1', 'x_z2', 'y_modloss_z1', 'y_modloss_z2'], # Will be overridden by the model weights\n",
Collaborator commented:

Can we leave the charged_frag_types empty to emphasize this?
charged_frag_types = []

Collaborator (Author) commented:

Not completely empty, since this initializes a PyTorch model, but I changed it to a single frag type and added print statements to emphasize this use case.

"output_type": "execute_result"
}
],
"source": [
Collaborator commented:

Can we add an example showing that the order is automatically controlled by alphabase?

Even if the user requests a random order, both the saved and predicted frag types are always in the defined order from sort_charged_frag_types.
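A defined order of the kind sort_charged_frag_types enforces can be sketched as follows (a hypothetical stand-in, assuming names of the form `<type>_z<charge>`, not alphabase's actual implementation):

```python
# Hypothetical stand-in for alphabase's sort_charged_frag_types: sort by
# fragment name first, then by charge, so any requested order collapses
# to one canonical column layout.
def sort_frag_types(frag_types):
    def key(ft):
        name, charge = ft.rsplit("_z", 1)
        return (name, int(charge))
    return sorted(frag_types, key=key)

print(sort_frag_types(["y_z2", "b_z1", "y_modloss_z1", "b_z2", "y_z1"]))
# → ['b_z1', 'b_z2', 'y_z1', 'y_z2', 'y_modloss_z1']
```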

**kwargs, # model params
):
super().__init__(device=device)
if mask_modloss is not None:
warnings.warn("mask_modloss is deprecated and will be removed in the future. To mask the modloss fragments, the charged_frag_types should not include the modloss fragments.")
Collaborator commented:

Does mask_modloss still work in this version?

Collaborator (Author) commented:

Yes, and it is also one of the test cases.

device : str, optional
Device to run the model, by default "gpu"
override_from_weights : bool, optional, default False
Override the requested charged frag types when loading the model from weights; this will
Collaborator commented:

Maybe make it clearer what this is doing:

Always update the requested charged fragment types from the model weights on loading. This allows predicting all fragment types supported by the weights, even if the user doesn't know which fragment types the weights support. Thereby, the model will always be in a safe-to-predict state.

@@ -504,6 +530,41 @@ def _set_batch_predict_data(
predicts.reshape((-1, len(self.charged_frag_types))),
batch_df[["frag_start_idx", "frag_stop_idx"]].values,
)
def _align_model_charged_frag_types(self):
"""
Align the underlying model charged_frag_types with the interface charged_frag_types,
Collaborator commented:

Are the interface charged frag types the requested ones?

Collaborator commented:

I feel like the impact could be clearer; "align" sounds very vague.
Maybe something like reshape_prediction_head?
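Whatever the name, the operation itself could be sketched like this (a plain-Python stand-in with hypothetical names; the real code would act on the torch tensors of the output head):

```python
# Plain-Python sketch of what a reshape_prediction_head could do: build a
# new output matrix over the requested frag types, reusing rows the old
# head already has and freshly initializing rows for new types.
import random

def reshape_head(old_weights, old_types, new_types, init_scale=0.02):
    dim = len(old_weights[0])
    new_weights = []
    for ft in new_types:
        if ft in old_types:
            # Keep learned parameters for frag types both heads share.
            new_weights.append(old_weights[old_types.index(ft)])
        else:
            # New frag type: small random init, needs (re)training.
            new_weights.append([random.gauss(0.0, init_scale) for _ in range(dim)])
    return new_weights
```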

@@ -587,6 +655,10 @@ def predict(
reference_frag_df=None,
**kwargs,
) -> pd.DataFrame:
if not self._safe_to_predict:
Collaborator commented:

Please correct me, but shouldn't the most obvious solution be for users to select different (supported) charged frag types? I would mention this first, and then using different weights and retraining.

f"nn parameters {unexpect_keys} are UNEXPECTED while loading models in {self.__class__}"
)

def _update_model_state(self):
Collaborator commented:

Could this lead to a deviation between _safe_to_predict and reality?
Do you think it would make sense to do this as part of a @property?

Collaborator (Author) commented:

Yes, we can. The only use case preventing that is:

  • A model is loaded and is (safe to predict and not safe to train).
  • The prediction-head reshape is called, so both the supported and requested frag types are now aligned, which makes the model technically safe to predict. Having the state update as part of the getters would then allow prediction with a randomly initialized prediction head even if training didn't start (e.g. it crashed).

Currently, the model goes from 'not safe to predict' to 'safe to predict' only after training completes.

Collaborator commented:

Ah I see 👍🏻

@@ -884,3 +995,31 @@ def calc_ms2_similarity(

torch.cuda.empty_cache()
return psm_df, metrics_describ


def charged_frags_to_tensor(charged_frags: List[str]) -> torch.Tensor:
Collaborator commented:

Yes, this looks very nice.
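The idea behind charged_frags_to_tensor (storing the frag-type names inside the numeric weights file) can be sketched without torch; this is a hypothetical round-trip, not the PR's implementation:

```python
# Hypothetical sketch of charged_frags_to_tensor's idea: encode the joined
# frag-type names as integer codes so they can live alongside the weights,
# then decode them back on load.
def frags_to_codes(frags):
    return [ord(c) for c in ";".join(frags)]

def codes_to_frags(codes):
    return "".join(map(chr, codes)).split(";")

roundtrip = codes_to_frags(frags_to_codes(["b_z1", "y_z1", "y_modloss_z1"]))
assert roundtrip == ["b_z1", "y_z1", "y_modloss_z1"]
```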
