diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index 5f606a89..8631fde9 100644 --- a/multimolecule/__init__.py +++ b/multimolecule/__init__.py @@ -8,6 +8,10 @@ UtrLmForCrisprOffTarget, ) from .models import ( + BaseHeadConfig, + HeadConfig, + MaskedLMHeadConfig, + PretrainedConfig, RnaBertConfig, RnaBertForMaskedLM, RnaBertForNucleotideClassification, @@ -54,6 +58,10 @@ from .tokenisers import RnaTokenizer __all__ = [ + "PretrainedConfig", + "BaseHeadConfig", + "HeadConfig", + "MaskedLMHeadConfig", "models", "tokenisers", "RnaTokenizer", diff --git a/multimolecule/downstream/crispr_off_target.py b/multimolecule/downstream/crispr_off_target.py index 55bad371..8aa7998b 100644 --- a/multimolecule/downstream/crispr_off_target.py +++ b/multimolecule/downstream/crispr_off_target.py @@ -7,13 +7,13 @@ from torch import Tensor from transformers.modeling_outputs import ModelOutput -from multimolecule.models.modeling_utils import ClassificationHead, Criterion from multimolecule.models.rnabert import RnaBertConfig, RnaBertModel, RnaBertPreTrainedModel from multimolecule.models.rnafm import RnaFmConfig, RnaFmModel, RnaFmPreTrainedModel from multimolecule.models.rnamsm import RnaMsmConfig, RnaMsmModel, RnaMsmPreTrainedModel from multimolecule.models.splicebert import SpliceBertConfig, SpliceBertModel, SpliceBertPreTrainedModel from multimolecule.models.utrbert import UtrBertConfig, UtrBertModel, UtrBertPreTrainedModel from multimolecule.models.utrlm import UtrLmConfig, UtrLmModel, UtrLmPreTrainedModel +from multimolecule.module import ClassificationHead, Criterion class RnaBertForCrisprOffTarget(RnaBertPreTrainedModel): diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index ab22212f..9b28057e 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -1,4 +1,5 @@ from ..tokenisers.rna import RnaTokenizer +from .configuration_utils import BaseHeadConfig, HeadConfig, MaskedLMHeadConfig, PretrainedConfig from .rnabert import ( RnaBertConfig, RnaBertForMaskedLM, @@ -55,6 +56,10 @@ ) __all__ = [ + "PretrainedConfig", + "HeadConfig", + "BaseHeadConfig", + "MaskedLMHeadConfig", "RnaTokenizer", "RnaBertConfig", "RnaBertModel", diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py index c6e33793..99f5774c 100644 --- a/multimolecule/models/configuration_utils.py +++ b/multimolecule/models/configuration_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from dataclasses import asdict, dataclass, is_dataclass from typing import Optional @@ -7,7 +8,7 @@ class PretrainedConfig(_PretrainedConfig): - head: HeadConfig + head: BaseHeadConfig def __init__( self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs @@ -38,8 +39,12 @@ def to_dict(self): return output +class HeadConfig(OrderedDict): + pass + + @dataclass -class HeadConfig: +class BaseHeadConfig(HeadConfig): r""" This is the configuration class to store the configuration of a prediction head. It is used to instantiate a prediction head according to the specified arguments, defining the head architecture. @@ -82,7 +87,7 @@ class HeadConfig: @dataclass -class MaskedLMHeadConfig: +class MaskedLMHeadConfig(HeadConfig): r""" This is the configuration class to store the configuration of a prediction head. It is used to instantiate a prediction head according to the specified arguments, defining the head architecture. 
diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py index 5b6edc52..b6f3112c 100644 --- a/multimolecule/models/modeling_utils.py +++ b/multimolecule/models/modeling_utils.py @@ -1,481 +1,6 @@ from __future__ import annotations -from functools import partial -from typing import Optional, Tuple - import torch -from chanfig import ConfigRegistry -from torch import Tensor, nn -from torch.nn import functional as F -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ModelOutput - -from .configuration_utils import HeadConfig, PretrainedConfig - -TokenHeads = ConfigRegistry(key="tokenizer_type") -NucleotideHeads = ConfigRegistry(key="tokenizer_type") - - -class ContactPredictionHead(nn.Module): - """ - Head for contact-map-level tasks. - Performs symmetrization, and average product correct. - """ - - def __init__(self, config: PretrainedConfig): - super().__init__() - self.config = config.head - if self.config.hidden_size is None: - self.config.hidden_size = config.hidden_size - if self.config.num_labels is None: - self.config.num_labels = config.num_labels - if self.config.problem_type is None: - self.config.problem_type = config.problem_type - self.num_labels = self.config.num_labels - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id - self.dropout = nn.Dropout(self.config.dropout) - self.transform = PredictionHeadTransform.build(self.config) - self.decoder = nn.Linear( - config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias - ) - self.activation = ACT2FN[self.config.act] if self.config.act is not None else None - - def forward( - self, attentions: Tensor, attention_mask: Optional[Tensor] = None, input_ids: Optional[Tensor] = None - ) -> Tensor: - if attention_mask is None: - if input_ids is None: - raise ValueError( - "Either attention_mask or input_ids must be provided for ContactPredictionHead to work." - ) - if self.pad_token_id is None: - raise ValueError( - "pad_token_id must be provided when attention_mask is not passed to ContactPredictionHead." - ) - attention_mask = input_ids.ne(self.pad_token_id) - # In the original model, attentions for padding tokens are completely zeroed out. - # This makes no difference most of the time because the other tokens won't attend to them, - # but it does for the contact prediction task, which takes attentions as input, - # so we have to mimic that here. 
- attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) - attentions *= attention_mask[:, None, None, :, :] - # remove cls token attentions - if self.bos_token_id is not None: - attentions = attentions[..., 1:, 1:] - attention_mask = attention_mask[..., 1:, 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token attentions - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(attentions) - input_ids = input_ids[..., 1:] - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices - eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) - attentions *= eos_mask[:, None, None, :, :] - attentions = attentions[..., :-1, :-1] - attention_mask = attention_mask[..., 1:, 1:] - - # features: batch x channels x input_ids x input_ids (symmetric) - batch_size, layers, heads, seqlen, _ = attentions.size() - attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) - attentions = attentions.to(self.decoder.weight.device) - attentions = average_product_correct(symmetrize(attentions)) - attentions = attentions.permute(0, 2, 3, 1) - output = self.dropout(attentions) - output = self.decoder(output).squeeze(3) - if self.activation is not None: - output = self.activation(output) - return output - - -class MaskedLMHead(nn.Module): - """Head for masked language modeling.""" - - def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None): - super().__init__() - self.config = config.lm_head if hasattr(config, "lm_head") else config.head - if self.config.hidden_size is None: - self.config.hidden_size = config.hidden_size - self.num_labels = config.vocab_size - self.dropout = nn.Dropout(self.config.dropout) - self.transform = PredictionHeadTransform.build(self.config) - self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False) - if weight is not None: - self.decoder.weight = weight - if self.config.bias: - self.bias = nn.Parameter(torch.zeros(self.num_labels)) - self.decoder.bias = self.bias - self.activation = ACT2FN[self.config.act] if self.config.act is not None else None - - def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: - sequence_output = outputs[0] - output = self.dropout(sequence_output) - output = self.transform(output) - output = self.decoder(output) - if self.activation is not None: - output = self.activation(output) - return output - - -class ClassificationHead(nn.Module): - """Head for all-level of tasks.""" - - num_labels: int - - def __init__(self, config: PretrainedConfig): - super().__init__() - self.config = config.head - if self.config.hidden_size is None: - self.config.hidden_size = config.hidden_size - if self.config.num_labels is None: - self.config.num_labels = config.num_labels - if self.config.problem_type is None: - self.config.problem_type = config.problem_type - self.num_labels = self.config.num_labels - self.dropout = nn.Dropout(self.config.dropout) - self.transform = PredictionHeadTransform.build(self.config) - self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias) - self.activation = ACT2FN[self.config.act] if self.config.act is not None else None - - def forward(self, embeddings: Tensor) -> Tensor: - output = self.dropout(embeddings) - output = self.transform(output) - output = self.decoder(output) - if self.activation is not 
None: - output = self.activation(output) - return output - - -class SequenceClassificationHead(ClassificationHead): - """Head for sequence-level tasks.""" - - def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylint: disable=arguments-renamed - output = super().forward(outputs[1]) - return output - - -@TokenHeads.register("single", default=True) -class TokenClassificationHead(ClassificationHead): - """Head for token-level tasks.""" - - def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylint: disable=arguments-renamed - output = super().forward(outputs[0]) - return output - - -@TokenHeads.register("kmer") -class TokenKMerHead(ClassificationHead): - """Head for token-level tasks.""" - - def __init__(self, config: PretrainedConfig): - super().__init__(config) - self.nmers = config.nmers - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id - self.unfold_kmer_embeddings = partial( - unfold_kmer_embeddings, nmers=self.nmers, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id - ) - - def forward( # pylint: disable=arguments-renamed - self, - outputs: ModelOutput | Tuple[Tensor, ...], - attention_mask: Optional[Tensor] = None, - input_ids: Optional[Tensor] = None, - ) -> Tensor: - if attention_mask is None: - if input_ids is None: - raise ValueError("Either attention_mask or input_ids must be provided for TokenKMerHead to work.") - if self.pad_token_id is None: - raise ValueError("pad_token_id must be provided when attention_mask is not passed to TokenKMerHead.") - attention_mask = input_ids.ne(self.pad_token_id) - - output = outputs[0] - output = self.unfold_kmer_embeddings(output, attention_mask) - output = super().forward(output) - return output - - -@NucleotideHeads.register("single", default=True) -class NucleotideClassificationHead(ClassificationHead): - """Head for nucleotide-level tasks.""" - - def __init__(self, config: PretrainedConfig): - super().__init__(config) - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id - - def forward( # pylint: disable=arguments-renamed - self, - outputs: ModelOutput | Tuple[Tensor, ...], - attention_mask: Optional[Tensor] = None, - input_ids: Optional[Tensor] = None, - ) -> Tensor: - if attention_mask is None: - if input_ids is None: - raise ValueError( - "Either attention_mask or input_ids must be provided for NucleotideClassificationHead to work." - ) - if self.pad_token_id is None: - raise ValueError( - "pad_token_id must be provided when attention_mask is not passed to NucleotideClassificationHead." 
- ) - attention_mask = input_ids.ne(self.pad_token_id) - - output = outputs[0] - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - input_ids = input_ids[..., 1:] - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output *= eos_mask[:, :, None] - output = output[..., :-1, :] - attention_mask = attention_mask[..., 1:] - - output = super().forward(output) - return output - - -@NucleotideHeads.register("kmer") -class NucleotideKMerHead(ClassificationHead): - """Head for nucleotide-level tasks.""" - - def __init__(self, config: PretrainedConfig): - super().__init__(config) - self.nmers = config.nmers - self.bos_token_id = None # Nucleotide-level head removes token. - self.eos_token_id = None # Nucleotide-level head removes token. - self.pad_token_id = config.pad_token_id - self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers) - - def forward( # pylint: disable=arguments-renamed - self, - outputs: ModelOutput | Tuple[Tensor, ...], - attention_mask: Optional[Tensor] = None, - input_ids: Optional[Tensor] = None, - ) -> Tensor: - if attention_mask is None: - if input_ids is None: - raise ValueError("Either attention_mask or input_ids must be provided for NucleotideKMerHead to work.") - if self.pad_token_id is None: - raise ValueError( - "pad_token_id must be provided when attention_mask is not passed to NucleotideKMerHead." 
- ) - attention_mask = input_ids.ne(self.pad_token_id) - - output = outputs[0] - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - input_ids = input_ids[..., 1:] - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output *= eos_mask[:, :, None] - output = output[..., :-1, :] - attention_mask = attention_mask[..., 1:] - - output = self.unfold_kmer_embeddings(output, attention_mask) - output = super().forward(output) - return output - - -class Criterion(nn.Module): - - problem_types = ["regression", "single_label_classification", "multi_label_classification"] - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - self.config = config.head - self.problem_type = self.config.problem_type - self.num_labels = self.config.num_labels - - def forward(self, logits, labels) -> Tensor | None: - if labels is None: - return None - if self.problem_type is None: - if self.num_labels == 1: - self.problem_type = "regression" - elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): - self.problem_type = "single_label_classification" - else: - self.problem_type = "multi_label_classification" - self.config.problem_type = self.problem_type - if self.problem_type == "regression": - return ( - F.mse_loss(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else F.mse_loss(logits, labels) - ) - if self.problem_type == "single_label_classification": - return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) - if self.problem_type == "multi_label_classification": - return F.binary_cross_entropy_with_logits(logits, labels) - raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}") - - -PredictionHeadTransform = ConfigRegistry(key="transform") - - -@PredictionHeadTransform.register("nonlinear") -class NonLinearTransform(nn.Module): - def __init__(self, config: HeadConfig): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.transform_act, str): - self.transform_act_fn = ACT2FN[config.transform_act] - else: - self.transform_act_fn = config.transform_act - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states: Tensor) -> Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -@PredictionHeadTransform.register("linear") -class LinearTransform(nn.Module): - def __init__(self, config: HeadConfig): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states: Tensor) -> Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -@PredictionHeadTransform.register(None) -class IdentityTransform(nn.Identity): - def __init__(self, config: HeadConfig): # pylint: disable=unused-argument - super().__init__() - - -def 
unfold_kmer_embeddings( - embeddings: Tensor, - attention_mask: Tensor, - nmers: int, - bos_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, -) -> Tensor: - r""" - Unfold k-mer embeddings to token embeddings. - - For k-mer input, each embedding column represents k tokens. - This should be fine for sequence level tasks, but sacrifices the resolution for token level tasks. - This function unfolds the k-mer embeddings to token embeddings by sliding averaging the k-mer embeddings. - - For example: - - input tokens = `ACGU` - - 2-mer embeddings = `[, AC, CG, GU, ]`. - - token embeddings = `[, AC, (AC + CG) / 2, (CG + GU) / 2, GU, ]`. - - Args: - embeddings: The k-mer embeddings. - attention_mask: The attention mask. - nmers: The number of tokens in each k-mer. - bos_token_id: The id of the beginning of sequence token. - If not None, the first valid token will not be included in sliding averaging. - eos_token_id: The id of the end of sequence token. - If not None, the last valid token will not be included in sliding averaging. - - Returns: - The token embeddings. - - Examples: - >>> from danling import NestedTensor - >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1 - >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True) - >>> output[0, :, 0].tolist() - [1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0] - >>> output[1, :, 0].tolist() - [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0] - >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1 - >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True) - >>> output[0, :, 0].tolist() - [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0] - >>> output[1, :, 0].tolist() - [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0] - >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1 - >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True) - >>> output[0, :, 0].tolist() - [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0] - >>> output[1, :, 0].tolist() - [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0] - >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1 - >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True) - >>> output[0, :, 0].tolist() - [1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0] - >>> output[1, :, 0].tolist() - [1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0] - >>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1 - >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6) - >>> output[0, :, 0].tolist() - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] - >>> output[1, :, 0].tolist() - [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0] - """ - - batch_size, seq_length, hidden_size = embeddings.size() - last_valid_indices = attention_mask.sum(dim=-1) - output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device) - for index, (tensor, seq_len) in enumerate(zip(embeddings, last_valid_indices)): - embedding = tensor[:seq_len] - if bos_token_id is not None: - embedding = embedding[1:] - if eos_token_id is not None: - embedding = embedding[:-1] - if len(embedding) > nmers: - begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)]) - medium = embedding.unfold(0, nmers, 1).mean(-1) - end = 
torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)]) - embedding = torch.cat([begin, medium, end]) - elif len(embedding) > 2: - begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))]) - end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)]) - embedding = torch.cat([begin, end]) - elif len(embedding) == 2: - medium = embedding.mean(0).repeat(nmers - 1, 1) - embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]]) - elif len(embedding) == 1: - embedding = embedding.repeat(nmers, 1) - else: - raise ValueError("Sequence length is less than nmers.") - if bos_token_id is not None: - embedding = torch.cat([tensor[0][None, :], embedding]) - if eos_token_id is not None: - embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]]) - output[index, : seq_len + nmers - 1] = embedding - return output def rotate_half(x): @@ -488,20 +13,3 @@ def apply_rotary_pos_emb(x, cos, sin): sin = sin[:, :, : x.shape[-2], :] return (x * cos) + (rotate_half(x) * sin) - - -def symmetrize(x): - "Make layer symmetric in final two dimensions, used for contact prediction." - return x + x.transpose(-1, -2) - - -def average_product_correct(x): - "Perform average product correct, used for contact prediction." - a1 = x.sum(-1, keepdims=True) - a2 = x.sum(-2, keepdims=True) - a12 = x.sum((-1, -2), keepdims=True) - - avg = a1 * a2 - avg.div_(a12) # in-place to reduce memory - normalized = x - avg - return normalized diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index 4d22e72e..71d11571 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -95,5 +95,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index aa768be0..3ca3020f 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -19,13 +19,14 @@ TokenClassifierOutput, ) -from ..modeling_utils import ( +from multimolecule.module import ( Criterion, MaskedLMHead, NucleotideClassificationHead, SequenceClassificationHead, TokenClassificationHead, ) + from .configuration_rnabert import RnaBertConfig diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py index d372a7ea..de58c9c2 100644 --- a/multimolecule/models/rnafm/configuration_rnafm.py +++ b/multimolecule/models/rnafm/configuration_rnafm.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -108,5 +108,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before 
self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py index 94205db9..673e5ce9 100755 --- a/multimolecule/models/rnafm/modeling_rnafm.py +++ b/multimolecule/models/rnafm/modeling_rnafm.py @@ -20,15 +20,16 @@ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from ..modeling_utils import ( +from multimolecule.module import ( ContactPredictionHead, Criterion, MaskedLMHead, NucleotideClassificationHead, SequenceClassificationHead, TokenClassificationHead, - apply_rotary_pos_emb, ) + +from ..modeling_utils import apply_rotary_pos_emb from .configuration_rnafm import RnaFmConfig logger = logging.get_logger(__name__) diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py index 2232e8ac..aae090f9 100644 --- a/multimolecule/models/rnamsm/configuration_rnamsm.py +++ b/multimolecule/models/rnamsm/configuration_rnamsm.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -97,5 +97,5 @@ def __init__( self.attention_type = attention_type self.embed_positions_msa = embed_positions_msa self.attention_bias = attention_bias - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py index 44d7f62b..f9ce9ed8 100644 --- a/multimolecule/models/rnamsm/modeling_rnamsm.py +++ b/multimolecule/models/rnamsm/modeling_rnamsm.py @@ -14,7 +14,7 @@ from transformers.activations import ACT2FN from transformers.modeling_outputs import ModelOutput -from ..modeling_utils import ( +from multimolecule.module import ( ContactPredictionHead, Criterion, MaskedLMHead, @@ -22,6 +22,7 @@ SequenceClassificationHead, TokenClassificationHead, ) + from .configuration_rnamsm import RnaMsmConfig diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py index 6b70357d..79452260 100644 --- a/multimolecule/models/splicebert/configuration_splicebert.py +++ b/multimolecule/models/splicebert/configuration_splicebert.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -90,5 +90,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py index 
8fd463b0..5fe1bf44 100644 --- a/multimolecule/models/splicebert/modeling_splicebert.py +++ b/multimolecule/models/splicebert/modeling_splicebert.py @@ -20,13 +20,14 @@ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from ..modeling_utils import ( +from multimolecule.module import ( Criterion, MaskedLMHead, NucleotideClassificationHead, SequenceClassificationHead, TokenClassificationHead, ) + from .configuration_splicebert import SpliceBertConfig try: diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py index 854e193c..f5107d81 100644 --- a/multimolecule/models/utrbert/configuration_utrbert.py +++ b/multimolecule/models/utrbert/configuration_utrbert.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -110,5 +110,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py index 09af8bfb..cbf15194 100644 --- a/multimolecule/models/utrbert/modeling_utrbert.py +++ b/multimolecule/models/utrbert/modeling_utrbert.py @@ -20,7 +20,8 @@ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from ..modeling_utils import Criterion, MaskedLMHead, NucleotideKMerHead, SequenceClassificationHead, TokenKMerHead +from multimolecule.module import Criterion, MaskedLMHead, NucleotideKMerHead, SequenceClassificationHead, TokenKMerHead + from .configuration_utrbert import UtrBertConfig logger = logging.get_logger(__name__) diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py index b965b7c1..8608a77e 100644 --- a/multimolecule/models/utrlm/configuration_utrlm.py +++ b/multimolecule/models/utrlm/configuration_utrlm.py @@ -1,6 +1,6 @@ from transformers.utils import logging -from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig +from ..configuration_utils import BaseHeadConfig, MaskedLMHeadConfig, PretrainedConfig logger = logging.get_logger(__name__) @@ -119,7 +119,7 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) + self.head = BaseHeadConfig(**head if head is not None else {}) self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) - self.structure_head = HeadConfig(**structure_head) if structure_head is not None else None - self.supervised_head = HeadConfig(**supervised_head) if supervised_head is not None else None + self.structure_head = BaseHeadConfig(**structure_head) if structure_head is not None else None + self.supervised_head = BaseHeadConfig(**supervised_head) if supervised_head is not None else None diff --git a/multimolecule/models/utrlm/modeling_utrlm.py 
b/multimolecule/models/utrlm/modeling_utrlm.py index 59ff734f..09fd31d3 100755 --- a/multimolecule/models/utrlm/modeling_utrlm.py +++ b/multimolecule/models/utrlm/modeling_utrlm.py @@ -20,15 +20,16 @@ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from ..modeling_utils import ( +from multimolecule.module import ( ContactPredictionHead, Criterion, MaskedLMHead, NucleotideClassificationHead, SequenceClassificationHead, TokenClassificationHead, - apply_rotary_pos_emb, ) + +from ..modeling_utils import apply_rotary_pos_emb from .configuration_utrlm import UtrLmConfig logger = logging.get_logger(__name__) diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py new file mode 100644 index 00000000..d11b0332 --- /dev/null +++ b/multimolecule/module/__init__.py @@ -0,0 +1,35 @@ +from .critersions import Criterion +from .heads import ( + ClassificationHead, + ContactPredictionHead, + HeadTransforms, + IdentityTransform, + LinearTransform, + MaskedLMHead, + NonLinearTransform, + NucleotideClassificationHead, + NucleotideHeads, + NucleotideKMerHead, + SequenceClassificationHead, + TokenClassificationHead, + TokenHeads, + TokenKMerHead, +) + +__all__ = [ + "ClassificationHead", + "SequenceClassificationHead", + "TokenHeads", + "TokenClassificationHead", + "TokenKMerHead", + "NucleotideHeads", + "NucleotideClassificationHead", + "NucleotideKMerHead", + "ContactPredictionHead", + "MaskedLMHead", + "HeadTransforms", + "LinearTransform", + "NonLinearTransform", + "IdentityTransform", + "Criterion", +] diff --git a/multimolecule/module/critersions/__init__.py b/multimolecule/module/critersions/__init__.py new file mode 100644 index 00000000..9e9c19f4 --- /dev/null +++ b/multimolecule/module/critersions/__init__.py @@ -0,0 +1,3 @@ +from .generic import Criterion + +__all__ = ["Criterion"] diff --git a/multimolecule/module/critersions/generic.py b/multimolecule/module/critersions/generic.py new file mode 100644 index 00000000..dff1da05 --- /dev/null +++ b/multimolecule/module/critersions/generic.py @@ -0,0 +1,36 @@ +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from transformers import PretrainedConfig + + +class Criterion(nn.Module): + + problem_types = ["regression", "single_label_classification", "multi_label_classification"] + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + self.config = config.head + self.problem_type = self.config.problem_type + self.num_labels = self.config.num_labels + + def forward(self, logits, labels) -> Tensor | None: + if labels is None: + return None + if self.problem_type is None: + if self.num_labels == 1: + self.problem_type = "regression" + elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): + self.problem_type = "single_label_classification" + else: + self.problem_type = "multi_label_classification" + self.config.problem_type = self.problem_type + if self.problem_type == "regression": + return ( + F.mse_loss(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else F.mse_loss(logits, labels) + ) + if self.problem_type == "single_label_classification": + return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + if self.problem_type == "multi_label_classification": + return F.binary_cross_entropy_with_logits(logits, labels) + raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}") diff --git 
a/multimolecule/module/heads/__init__.py b/multimolecule/module/heads/__init__.py new file mode 100644 index 00000000..c30fd30f --- /dev/null +++ b/multimolecule/module/heads/__init__.py @@ -0,0 +1,24 @@ +from .contact import ContactPredictionHead +from .generic import ClassificationHead +from .nuleotide import NucleotideClassificationHead, NucleotideHeads, NucleotideKMerHead +from .pretrain import MaskedLMHead +from .sequence import SequenceClassificationHead +from .token import TokenClassificationHead, TokenHeads, TokenKMerHead +from .transform import HeadTransforms, IdentityTransform, LinearTransform, NonLinearTransform + +__all__ = [ + "ClassificationHead", + "SequenceClassificationHead", + "TokenHeads", + "TokenClassificationHead", + "TokenKMerHead", + "NucleotideHeads", + "NucleotideClassificationHead", + "NucleotideKMerHead", + "ContactPredictionHead", + "MaskedLMHead", + "HeadTransforms", + "LinearTransform", + "NonLinearTransform", + "IdentityTransform", +] diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py new file mode 100644 index 00000000..22fe7869 --- /dev/null +++ b/multimolecule/module/heads/contact.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import torch +from torch import Tensor, nn +from transformers.activations import ACT2FN + +from multimolecule.models.configuration_utils import PretrainedConfig + +from .transform import HeadTransforms +from .utils import average_product_correct, symmetrize + + +class ContactPredictionHead(nn.Module): + """ + Head for contact-map-level tasks. + Performs symmetrization, and average product correct. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config.head + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size + if self.config.num_labels is None: + self.config.num_labels = config.num_labels + if self.config.problem_type is None: + self.config.problem_type = config.problem_type + self.num_labels = self.config.num_labels + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.pad_token_id = config.pad_token_id + self.dropout = nn.Dropout(self.config.dropout) + self.transform = HeadTransforms.build(self.config) + self.decoder = nn.Linear( + config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias + ) + self.activation = ACT2FN[self.config.act] if self.config.act is not None else None + + def forward( + self, attentions: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None + ) -> Tensor: + if attention_mask is None: + if input_ids is None: + raise ValueError( + "Either attention_mask or input_ids must be provided for ContactPredictionHead to work." + ) + if self.pad_token_id is None: + raise ValueError( + "pad_token_id must be provided when attention_mask is not passed to ContactPredictionHead." + ) + attention_mask = input_ids.ne(self.pad_token_id) + # In the original model, attentions for padding tokens are completely zeroed out. + # This makes no difference most of the time because the other tokens won't attend to them, + # but it does for the contact prediction task, which takes attentions as input, + # so we have to mimic that here. 
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions *= attention_mask[:, None, None, :, :] + # remove cls token attentions + if self.bos_token_id is not None: + attentions = attentions[..., 1:, 1:] + attention_mask = attention_mask[..., 1:, 1:] + if input_ids is not None: + input_ids = input_ids[..., 1:] + # remove eos token attentions + if self.eos_token_id is not None: + if input_ids is not None: + eos_mask = input_ids.ne(self.eos_token_id).to(attentions) + input_ids = input_ids[..., 1:] + else: + last_valid_indices = attention_mask.sum(dim=-1) + seq_length = attention_mask.size(-1) + eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices + eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) + attentions *= eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + attention_mask = attention_mask[..., 1:, 1:] + + # features: batch x channels x input_ids x input_ids (symmetric) + batch_size, layers, heads, seqlen, _ = attentions.size() + attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) + attentions = attentions.to(self.decoder.weight.device) + attentions = average_product_correct(symmetrize(attentions)) + attentions = attentions.permute(0, 2, 3, 1) + output = self.dropout(attentions) + output = self.decoder(output).squeeze(3) + if self.activation is not None: + output = self.activation(output) + return output diff --git a/multimolecule/module/heads/generic.py b/multimolecule/module/heads/generic.py new file mode 100644 index 00000000..b80b8e5e --- /dev/null +++ b/multimolecule/module/heads/generic.py @@ -0,0 +1,35 @@ +from torch import Tensor, nn +from transformers.activations import ACT2FN + +from multimolecule.models.configuration_utils import PretrainedConfig + +from .transform import HeadTransforms + + +class ClassificationHead(nn.Module): + """Head for all-level of tasks.""" + + num_labels: int + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config.head + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size + if self.config.num_labels is None: + self.config.num_labels = config.num_labels + if self.config.problem_type is None: + self.config.problem_type = config.problem_type + self.num_labels = self.config.num_labels + self.dropout = nn.Dropout(self.config.dropout) + self.transform = HeadTransforms.build(self.config) + self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias) + self.activation = ACT2FN[self.config.act] if self.config.act is not None else None + + def forward(self, embeddings: Tensor) -> Tensor: + output = self.dropout(embeddings) + output = self.transform(output) + output = self.decoder(output) + if self.activation is not None: + output = self.activation(output) + return output diff --git a/multimolecule/module/heads/nuleotide.py b/multimolecule/module/heads/nuleotide.py new file mode 100644 index 00000000..d8e455e0 --- /dev/null +++ b/multimolecule/module/heads/nuleotide.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from functools import partial +from typing import Tuple + +import torch +from chanfig import ConfigRegistry +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput + +from multimolecule.models.configuration_utils import PretrainedConfig + +from .generic import ClassificationHead +from .utils import unfold_kmer_embeddings + +NucleotideHeads = 
ConfigRegistry(key="tokenizer_type") + + +@NucleotideHeads.register("single", default=True) +class NucleotideClassificationHead(ClassificationHead): + """Head for nucleotide-level tasks.""" + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.pad_token_id = config.pad_token_id + + def forward( # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: Tensor | None = None, + ) -> Tensor: + if attention_mask is None: + if input_ids is None: + raise ValueError( + "Either attention_mask or input_ids must be provided for NucleotideClassificationHead to work." + ) + if self.pad_token_id is None: + raise ValueError( + "pad_token_id must be provided when attention_mask is not passed to NucleotideClassificationHead." + ) + attention_mask = input_ids.ne(self.pad_token_id) + + output = outputs[0] + # remove cls token embeddings + if self.bos_token_id is not None: + output = output[..., 1:, :] + attention_mask = attention_mask[..., 1:] + if input_ids is not None: + input_ids = input_ids[..., 1:] + # remove eos token embeddings + if self.eos_token_id is not None: + if input_ids is not None: + eos_mask = input_ids.ne(self.eos_token_id).to(output) + input_ids = input_ids[..., 1:] + else: + last_valid_indices = attention_mask.sum(dim=-1) + seq_length = attention_mask.size(-1) + eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) + output *= eos_mask[:, :, None] + output = output[..., :-1, :] + attention_mask = attention_mask[..., 1:] + + output = super().forward(output) + return output + + +@NucleotideHeads.register("kmer") +class NucleotideKMerHead(ClassificationHead): + """Head for nucleotide-level tasks.""" + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.nmers = config.nmers + self.bos_token_id = None # Nucleotide-level head removes token. + self.eos_token_id = None # Nucleotide-level head removes token. + self.pad_token_id = config.pad_token_id + self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers) + + def forward( # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: Tensor | None = None, + ) -> Tensor: + if attention_mask is None: + if input_ids is None: + raise ValueError("Either attention_mask or input_ids must be provided for NucleotideKMerHead to work.") + if self.pad_token_id is None: + raise ValueError( + "pad_token_id must be provided when attention_mask is not passed to NucleotideKMerHead." 
+ ) + attention_mask = input_ids.ne(self.pad_token_id) + + output = outputs[0] + # remove cls token embeddings + if self.bos_token_id is not None: + output = output[..., 1:, :] + attention_mask = attention_mask[..., 1:] + if input_ids is not None: + input_ids = input_ids[..., 1:] + # remove eos token embeddings + if self.eos_token_id is not None: + if input_ids is not None: + eos_mask = input_ids.ne(self.eos_token_id).to(output) + input_ids = input_ids[..., 1:] + else: + last_valid_indices = attention_mask.sum(dim=-1) + seq_length = attention_mask.size(-1) + eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) + output *= eos_mask[:, :, None] + output = output[..., :-1, :] + attention_mask = attention_mask[..., 1:] + + output = self.unfold_kmer_embeddings(output, attention_mask) + output = super().forward(output) + return output diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py new file mode 100644 index 00000000..714cd032 --- /dev/null +++ b/multimolecule/module/heads/pretrain.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Tuple + +import torch +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput + +from multimolecule.models.configuration_utils import PretrainedConfig + +from .transform import HeadTransforms + + +class MaskedLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, config: PretrainedConfig, weight: Tensor | None = None): + super().__init__() + self.config = config.lm_head if hasattr(config, "lm_head") else config.head + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size + self.num_labels = config.vocab_size + self.dropout = nn.Dropout(self.config.dropout) + self.transform = HeadTransforms.build(self.config) + self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False) + if weight is not None: + self.decoder.weight = weight + if self.config.bias: + self.bias = nn.Parameter(torch.zeros(self.num_labels)) + self.decoder.bias = self.bias + self.activation = ACT2FN[self.config.act] if self.config.act is not None else None + + def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: + sequence_output = outputs[0] + output = self.dropout(sequence_output) + output = self.transform(output) + output = self.decoder(output) + if self.activation is not None: + output = self.activation(output) + return output diff --git a/multimolecule/module/heads/sequence.py b/multimolecule/module/heads/sequence.py new file mode 100644 index 00000000..552f9895 --- /dev/null +++ b/multimolecule/module/heads/sequence.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Tuple + +import torch +from chanfig import ConfigRegistry +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput + +from multimolecule.models.configuration_utils import PretrainedConfig + +from .generic import ClassificationHead + + +class SequenceClassificationHead(ClassificationHead): + """Head for sequence-level tasks.""" + + def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylint: disable=arguments-renamed + output = super().forward(outputs[1]) + return output diff --git a/multimolecule/module/heads/token.py b/multimolecule/module/heads/token.py new file mode 100644 index 00000000..5e553e99 --- /dev/null +++ 
b/multimolecule/module/heads/token.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from functools import partial +from typing import Tuple + +import torch +from chanfig import ConfigRegistry +from torch import Tensor, nn +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput + +from .utils import unfold_kmer_embeddings +from .generic import ClassificationHead + +TokenHeads = ConfigRegistry(key="tokenizer_type") + + +@TokenHeads.register("single", default=True) +class TokenClassificationHead(ClassificationHead): + """Head for token-level tasks.""" + + def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylint: disable=arguments-renamed + output = super().forward(outputs[0]) + return output + + +@TokenHeads.register("kmer") +class TokenKMerHead(ClassificationHead): + """Head for token-level tasks.""" + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.nmers = config.nmers + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.pad_token_id = config.pad_token_id + self.unfold_kmer_embeddings = partial( + unfold_kmer_embeddings, nmers=self.nmers, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id + ) + + def forward( # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: Tensor | None = None, + ) -> Tensor: + if attention_mask is None: + if input_ids is None: + raise ValueError("Either attention_mask or input_ids must be provided for TokenKMerHead to work.") + if self.pad_token_id is None: + raise ValueError("pad_token_id must be provided when attention_mask is not passed to TokenKMerHead.") + attention_mask = input_ids.ne(self.pad_token_id) + + output = outputs[0] + output = self.unfold_kmer_embeddings(output, attention_mask) + output = super().forward(output) + return output diff --git a/multimolecule/module/heads/transform.py b/multimolecule/module/heads/transform.py new file mode 100644 index 00000000..872ff433 --- /dev/null +++ b/multimolecule/module/heads/transform.py @@ -0,0 +1,44 @@ +from chanfig import ConfigRegistry +from torch import Tensor, nn +from transformers.activations import ACT2FN + +from multimolecule.models.configuration_utils import BaseHeadConfig + +HeadTransforms = ConfigRegistry(key="transform") + + +@HeadTransforms.register("nonlinear") +class NonLinearTransform(nn.Module): + def __init__(self, config: BaseHeadConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.transform_act, str): + self.transform_act_fn = ACT2FN[config.transform_act] + else: + self.transform_act_fn = config.transform_act + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +@HeadTransforms.register("linear") +class LinearTransform(nn.Module): + def __init__(self, config: BaseHeadConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = 
self.layer_norm(hidden_states) + return hidden_states + + +@HeadTransforms.register(None) +class IdentityTransform(nn.Identity): + def __init__(self, config: BaseHeadConfig): # pylint: disable=unused-argument + super().__init__() diff --git a/multimolecule/module/heads/utils.py b/multimolecule/module/heads/utils.py new file mode 100644 index 00000000..edf7cd27 --- /dev/null +++ b/multimolecule/module/heads/utils.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from typing import Optional + +import torch +from torch import Tensor + + +def unfold_kmer_embeddings( + embeddings: Tensor, + attention_mask: Tensor, + nmers: int, + bos_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, +) -> Tensor: + r""" + Unfold k-mer embeddings to token embeddings. + + For k-mer input, each embedding column represents k tokens. + This should be fine for sequence level tasks, but sacrifices the resolution for token level tasks. + This function unfolds the k-mer embeddings to token embeddings by sliding averaging the k-mer embeddings. + + For example: + + input tokens = `ACGU` + + 2-mer embeddings = `[, AC, CG, GU, ]`. + + token embeddings = `[, AC, (AC + CG) / 2, (CG + GU) / 2, GU, ]`. + + Args: + embeddings: The k-mer embeddings. + attention_mask: The attention mask. + nmers: The number of tokens in each k-mer. + bos_token_id: The id of the beginning of sequence token. + If not None, the first valid token will not be included in sliding averaging. + eos_token_id: The id of the end of sequence token. + If not None, the last valid token will not be included in sliding averaging. + + Returns: + The token embeddings. + + Examples: + >>> from danling import NestedTensor + >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1 + >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True) + >>> output[0, :, 0].tolist() + [1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0] + >>> output[1, :, 0].tolist() + [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0] + >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1 + >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True) + >>> output[0, :, 0].tolist() + [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0] + >>> output[1, :, 0].tolist() + [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0] + >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1 + >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True) + >>> output[0, :, 0].tolist() + [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0] + >>> output[1, :, 0].tolist() + [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0] + >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1 + >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True) + >>> output[0, :, 0].tolist() + [1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0] + >>> output[1, :, 0].tolist() + [1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0] + >>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1 + >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6) + >>> output[0, :, 0].tolist() + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] + >>> output[1, :, 0].tolist() + [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0] + """ + + batch_size, seq_length, hidden_size = 
embeddings.size() + last_valid_indices = attention_mask.sum(dim=-1) + output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device) + for index, (tensor, seq_len) in enumerate(zip(embeddings, last_valid_indices)): + embedding = tensor[:seq_len] + if bos_token_id is not None: + embedding = embedding[1:] + if eos_token_id is not None: + embedding = embedding[:-1] + if len(embedding) > nmers: + begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)]) + medium = embedding.unfold(0, nmers, 1).mean(-1) + end = torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)]) + embedding = torch.cat([begin, medium, end]) + elif len(embedding) > 2: + begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))]) + end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)]) + embedding = torch.cat([begin, end]) + elif len(embedding) == 2: + medium = embedding.mean(0).repeat(nmers - 1, 1) + embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]]) + elif len(embedding) == 1: + embedding = embedding.repeat(nmers, 1) + else: + raise ValueError("Sequence length is less than nmers.") + if bos_token_id is not None: + embedding = torch.cat([tensor[0][None, :], embedding]) + if eos_token_id is not None: + embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]]) + output[index, : seq_len + nmers - 1] = embedding + return output + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose(-1, -2) + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized
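The renamed configuration classes keep a common HeadConfig (OrderedDict) base, so both the task-head and LM-head configs can be type-checked uniformly, and the new names are re-exported at the package root. A quick check, assuming default construction with no head overrides works as in the model configs above:

from multimolecule import BaseHeadConfig, HeadConfig, MaskedLMHeadConfig, RnaBertConfig

config = RnaBertConfig()
assert isinstance(config.head, BaseHeadConfig)        # built as BaseHeadConfig(**{}) in the config __init__
assert isinstance(config.lm_head, MaskedLMHeadConfig)
# Both dataclasses share the new HeadConfig (OrderedDict) base.
assert isinstance(config.head, HeadConfig) and isinstance(config.lm_head, HeadConfig)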
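With the heads, transforms and criterion relocated from multimolecule/models/modeling_utils.py into the new multimolecule.module package, downstream code imports them from the package root. A minimal usage sketch, assuming this branch is installed together with torch, transformers and chanfig, and assuming the stock BaseHeadConfig defaults (dropout, bias, act, transform); shapes are illustrative:

import torch

from multimolecule import RnaBertConfig
from multimolecule.module import SequenceClassificationHead

config = RnaBertConfig(num_labels=2)
head = SequenceClassificationHead(config)      # reads config.head, falling back to config.hidden_size / num_labels

batch_size, seq_length = 4, 16
sequence_output = torch.randn(batch_size, seq_length, config.hidden_size)
pooled_output = torch.randn(batch_size, config.hidden_size)

# Sequence-level heads consume model outputs and read outputs[1] (the pooled representation).
logits = head((sequence_output, pooled_output))
print(logits.shape)                            # torch.Size([4, 2])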
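MaskedLMHead prefers config.lm_head (a MaskedLMHeadConfig) over config.head when the former is present, so the same pretrained config drives both the task head and the LM head. A short sketch under the same assumptions (stock MaskedLMHeadConfig defaults):

import torch

from multimolecule import RnaBertConfig
from multimolecule.module import MaskedLMHead

config = RnaBertConfig()
lm_head = MaskedLMHead(config)                 # picks config.lm_head because RnaBertConfig defines it

hidden_states = torch.randn(2, 10, config.hidden_size)
logits = lm_head((hidden_states,))             # outputs[0] = last hidden state
print(logits.shape)                            # torch.Size([2, 10, config.vocab_size])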
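NucleotideClassificationHead strips the cls/eos columns before classifying, deriving the attention mask from input_ids and pad_token_id when no mask is passed. A sketch with hypothetical token ids (6 stands in for an arbitrary nucleotide id; the special-token ids come from the shared PretrainedConfig defaults shown above):

import torch

from multimolecule import RnaBertConfig
from multimolecule.module import NucleotideClassificationHead

config = RnaBertConfig(num_labels=2)
head = NucleotideClassificationHead(config)

batch_size, seq_length = 2, 8                  # includes the cls and eos positions
hidden_states = torch.randn(batch_size, seq_length, config.hidden_size)
input_ids = torch.full((batch_size, seq_length), 6)   # 6 = arbitrary nucleotide id (assumption)
input_ids[:, 0] = config.bos_token_id
input_ids[:, -1] = config.eos_token_id

logits = head((hidden_states,), input_ids=input_ids)
print(logits.shape)                            # torch.Size([2, 6, 2]): cls/eos columns removed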
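Criterion reads num_labels and problem_type off config.head, the same object the heads patch in place, so in the models it is built alongside a head; with num_labels == 1 and no problem_type it falls back to regression and caches the inferred type on the head config. A minimal sketch, with num_labels set by hand here instead of by a head:

import torch

from multimolecule import RnaBertConfig
from multimolecule.module import Criterion

config = RnaBertConfig(num_labels=1)
config.head.num_labels = config.num_labels     # normally filled in by the head sharing this config
criterion = Criterion(config)

logits = torch.randn(8, 1)
labels = torch.rand(8)
loss = criterion(logits, labels)               # num_labels == 1 -> inferred as regression -> MSE loss
print(config.head.problem_type)                # "regression" is cached for subsequent calls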
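The contact-prediction helpers now live in multimolecule/module/heads/utils.py; ContactPredictionHead symmetrizes the stacked attention maps and applies average product correction before decoding. Exercised directly on a random map:

import torch

from multimolecule.module.heads.utils import average_product_correct, symmetrize

attentions = torch.rand(2, 12, 64, 64)         # batch x (layers * heads) x seq x seq
symmetric = symmetrize(attentions)
assert torch.allclose(symmetric, symmetric.transpose(-1, -2))   # symmetric in the last two dims
corrected = average_product_correct(symmetric)
print(corrected.shape)                         # torch.Size([2, 12, 64, 64])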