From 79aab3bf760bc84463f0b166f1cbef5171348d5f Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Tue, 2 Apr 2024 17:12:13 +0800 Subject: [PATCH] reorganise rnabert --- multimolecule/models/__init__.py | 18 +- multimolecule/models/rnabert/__init__.py | 30 +- .../models/rnabert/configuration_rnabert.py | 9 +- .../models/rnabert/convert_checkpoint.py | 31 +- .../models/rnabert/modeling_rnabert.py | 562 ++++++++++-------- pyproject.toml | 4 + 6 files changed, 406 insertions(+), 248 deletions(-) diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index a922cfc5..1f07452e 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -1,3 +1,17 @@ -from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer +from .rnabert import ( + RnaBertConfig, + RnaBertForMaskedLM, + RnaBertForSequenceClassification, + RnaBertForTokenClassification, + RnaBertModel, + RnaTokenizer, +) -__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"] +__all__ = [ + "RnaBertConfig", + "RnaBertModel", + "RnaBertForMaskedLM", + "RnaBertForSequenceClassification", + "RnaBertForTokenClassification", + "RnaTokenizer", +] diff --git a/multimolecule/models/rnabert/__init__.py b/multimolecule/models/rnabert/__init__.py index c388aeff..36c2b893 100644 --- a/multimolecule/models/rnabert/__init__.py +++ b/multimolecule/models/rnabert/__init__.py @@ -1,12 +1,36 @@ -from transformers import AutoConfig, AutoModel, AutoTokenizer +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForMaskedLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelWithLMHead, + AutoTokenizer, +) from multimolecule.tokenizers.rna import RnaTokenizer from .configuration_rnabert import RnaBertConfig -from .modeling_rnabert import RnaBertModel +from .modeling_rnabert import ( + RnaBertForMaskedLM, + RnaBertForSequenceClassification, + RnaBertForTokenClassification, + RnaBertModel, +) -__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"] +__all__ = [ + "RnaBertConfig", + "RnaBertModel", + "RnaTokenizer", + "RnaBertForMaskedLM", + "RnaBertForSequenceClassification", + "RnaBertForTokenClassification", +] AutoConfig.register("rnabert", RnaBertConfig) AutoModel.register(RnaBertConfig, RnaBertModel) +AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM) +AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification) +AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification) +AutoModelWithLMHead.register(RnaBertConfig, RnaBertForTokenClassification) AutoTokenizer.register(RnaBertConfig, RnaTokenizer) diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index d98468e2..c2af1179 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the RnaBert - [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture. + [mana438/RNABERT](https://github.com/mana438/RNABERT) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig): >>> # Initializing a model from the configuration >>> model = RnaBertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + """ model_type = "rnabert" @@ -77,6 +78,8 @@ def __init__( pad_token_id=0, position_embedding_type="absolute", use_cache=True, + classifier_dropout=None, + proj_head_mode="nonlinear", **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -97,3 +100,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.proj_head_mode = proj_head_mode diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index 0d5214a7..85649b33 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -6,7 +6,7 @@ import torch from torch import nn -from multimolecule.models import RnaBertConfig, RnaBertModel +from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list CONFIG = { @@ -19,7 +19,6 @@ "max_position_embeddings": 440, "num_attention_heads": 12, "num_hidden_layers": 6, - "vocab_size": 25, "ss_vocab_size": 8, "type_vocab_size": 2, "pad_token_id": 0, @@ -33,27 +32,41 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None): if output_path is None: output_path = "rnabert" config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG)) + config.vocab_size = len(vocab_list) ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu")) bert_state_dict = ckpt state_dict = {} - model = RnaBertModel(config) + model = RnaBertForMaskedLM(config) for key, value in bert_state_dict.items(): - if key.startswith("module.cls"): - continue - key = key[12:] + key = key[7:] key = key.replace("gamma", "weight") key = key.replace("beta", "bias") - state_dict[key] = value + if key.startswith("bert"): + state_dict["rna" + key] = value + continue + if key.startswith("cls"): + # import ipdb; ipdb.set_trace() + key = "lm_head." + key[4:] + # key = key[4:] + state_dict[key] = value + continue word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + word_embed_weight = word_embed.weight.data + predictions_bias = torch.zeros(config.vocab_size) + predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size)) # nn.init.normal_(pos_embed.weight, std=0.02) for original_token, new_token in zip(original_vocab_list, vocab_list): original_index = original_vocab_list.index(original_token) new_index = vocab_list.index(new_token) - word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index] - state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data + word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index] + predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index] + predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index] + state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight + state_dict["lm_head.predictions.bias"] = predictions_bias + state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight model.load_state_dict(state_dict) model.save_pretrained(output_path, safe_serialization=True) diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index b9698589..917de092 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -1,64 +1,291 @@ import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union import torch -from torch import nn +from torch import Tensor, nn from transformers import PreTrainedModel -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling - +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) + +from ..modeling_utils import SequenceClassificationHead, TokenClassificationHead from .configuration_rnabert import RnaBertConfig -class RnaBertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) # weightのこと - self.bias = nn.Parameter(torch.zeros(hidden_size)) # biasのこと - self.variance_epsilon = eps +class RnaBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias + config_class = RnaBertConfig + base_model_prefix = "rnabert" + supports_gradient_checkpointing = True + _no_split_modules = ["RnaBertLayer", "RnaBertEmbeddings"] + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) -class RnaBertEmbeddings(nn.Module): - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, input_ids, token_type_ids=None): - words_embeddings = self.word_embeddings(input_ids) +class RnaBertModel(RnaBertPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RnaBertEmbeddings(config) + self.encoder = RnaBertEncoder(config) + self.pooler = RnaBertPooler(config) if add_pooling_layer else None + # Initialize weights and apply final processing + self.post_init() + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - position_embeddings = self.position_embeddings(position_ids) + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - embeddings = words_embeddings + position_embeddings + token_type_embeddings + extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + # attention_mask=attention_mask, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - return embeddings + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) -class RnaBertLayer(nn.Module): + +class RnaBertForMaskedLM(RnaBertPreTrainedModel): + def __init__(self, config: RnaBertConfig): + super().__init__(config) + self.rnabert = RnaBertModel(config) + self.lm_head = RnaBertLMHead(config) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + outputs = self.rnabert( + input_ids, + token_type_ids, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head( + outputs.last_hidden_state, outputs.pooler_output + ) + + if not return_dict: + return (prediction_scores, prediction_scores_ss) + outputs[2:] + + return RnaBertMaskedLMOutput( + logits=prediction_scores, + logits_ss=prediction_scores_ss, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RnaBertForSequenceClassification(RnaBertPreTrainedModel): def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnabert = RnaBertModel(config, add_pooling_layer=False) + self.classifier = SequenceClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnabert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + +class RnaBertForTokenClassification(RnaBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnabert = RnaBertModel(config, add_pooling_layer=False) + self.classifier = TokenClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnabert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + +class RnaBertEncoder(nn.Module): + def __init__(self, config: RnaBertConfig): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)]) + # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) + # for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + for layer in self.layer: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] + + layer_outputs = layer(hidden_states, attention_mask, output_attentions) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) # type: ignore[operator] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class RnaBertLayer(nn.Module): + def __init__(self, config: RnaBertConfig): super().__init__() self.attention = RnaBertAttention(config) self.intermediate = RnaBertIntermediate(config) self.output = RnaBertOutput(config) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) attention_output, outputs = self_attention_outputs[0], self_attention_outputs[1:] intermediate_output = self.intermediate(attention_output) @@ -68,12 +295,12 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() self.selfattn = RnaBertSelfAttention(config) self.output = RnaBertSelfOutput(config) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): self_outputs = self.selfattn(hidden_states, attention_mask, output_attentions=output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -81,22 +308,17 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.num_attention_heads = config.num_attention_heads - # num_attention_heads': 12 - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): + def transpose_for_scores(self, x: Tensor): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, @@ -104,7 +326,7 @@ def transpose_for_scores(self, x): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) @@ -114,11 +336,8 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): value_layer = self.transpose_for_scores(mixed_value_layer) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_scores = attention_scores + attention_mask - - attention_probs = nn.Softmax(dim=-1)(attention_scores) - + attention_probs = attention_scores.softmax(-1) attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) @@ -131,275 +350,154 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertSelfOutput(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: Tensor, input_tensor: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states -def gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - class RnaBertIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act - self.intermediate_act_fn = gelu - - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class RnaBertOutput(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: Tensor, input_tensor: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states -class RnaBertEncoder(nn.Module): - def __init__(self, config): +class RnaBertEmbeddings(nn.Module): + def __init__(self, config: RnaBertConfig): super().__init__() - self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)]) - # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) - # for _ in range(config.num_hidden_layers)]) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward( - self, - hidden_states, - attention_mask, - output_attentions=False, - output_hidden_states=False, - return_dict=False, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - for layer in self.layer: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + def forward(self, input_ids: Tensor, token_type_ids: Optional[Tensor] = None): + words_embeddings = self.word_embeddings(input_ids) - layer_outputs = layer(hidden_states, attention_mask, output_attentions) - hidden_states = layer_outputs[0] + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_embeddings = self.position_embeddings(position_ids) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + embeddings = words_embeddings + position_embeddings + token_type_embeddings - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - all_hidden_states, - all_self_attentions, - ] - if v is not None - ) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings class RnaBertPooler(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output -class RnaBertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RnaBertConfig - base_model_prefix = "rnabert" - supports_gradient_checkpointing = True - _no_split_modules = ["RnaBertLayer", "RnaBertFoldTriangularSelfAttentionBlock", "RnaBertEmbeddings"] - - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -class RnaBertModel(RnaBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.embeddings = RnaBertEmbeddings(config) - self.encoder = RnaBertEncoder(config) - self.pooler = RnaBertPooler(config) - - def forward( - self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings( - input_ids=input_ids, - token_type_ids=token_type_ids, - # attention_mask=attention_mask, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] +class RnaBertLayerNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-12): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) # weightのこと + self.bias = nn.Parameter(torch.zeros(hidden_size)) # biasのこと + self.variance_epsilon = eps - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + def forward(self, x: Tensor): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias class RnaBertLMHead(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.predictions = MaskedWordPredictions(config, config.vocab_size) self.predictions_ss = MaskedWordPredictions(config, config.ss_vocab_size) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - def forward(self, sequence_output, pooled_output): + def forward(self, sequence_output: Tensor, pooled_output: Tensor): prediction_scores = self.predictions(sequence_output) prediction_scores_ss = self.predictions_ss(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, prediction_scores_ss, seq_relationship_score class MaskedWordPredictions(nn.Module): def __init__(self, config, vocab_size): super().__init__() - self.transform = RnaBertPredictionHeadTransform(config) - self.decoder = nn.Linear(in_features=config.hidden_size, out_features=vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) + # self.decoder.bias = self.bias - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) + self.bias - return hidden_states class RnaBertPredictionHeadTransform(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - - self.transform_act_fn = gelu - + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.dense(hidden_states) + # Note: Commted out in the original code # hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states -class RnaBertForMaskedLM(nn.Module): - def __init__(self, config): - super().__init__() - self.bert = RnaBertModel(config) - self.lm_head = RnaBertLMHead(config) - - def forward( - self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=False, - ): - outputs = self.bert( - input_ids, - token_type_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head( - outputs.last_hidden_state, outputs.pooler_output - ) - return prediction_scores, prediction_scores_ss, outputs +@dataclass +class RnaBertMaskedLMOutput(MaskedLMOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + logits_ss: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None diff --git a/pyproject.toml b/pyproject.toml index 536d2e09..5edeb3ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,10 @@ classifiers = [ dynamic = [ "version", ] +dependencies = [ + "chanfig", + "transformers", +] [project.urls] documentation = "https://multimolecule.danling.org" homepage = "https://multimolecule.danling.org"