Commit
reorganise rnabert
ZhiyuanChen committed Apr 2, 2024
1 parent d4e71c8 commit 79aab3b
Showing 6 changed files with 406 additions and 248 deletions.
18 changes: 16 additions & 2 deletions multimolecule/models/__init__.py
@@ -1,3 +1,17 @@
-from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
+from .rnabert import (
+    RnaBertConfig,
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+    RnaTokenizer,
+)

-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+    "RnaTokenizer",
+]
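With the widened export list, downstream code can import the task heads directly from multimolecule.models. A minimal sketch (random initialisation from a default config; no pretrained weights are downloaded):

from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM

# Build a masked-LM model from a default configuration.
config = RnaBertConfig()
model = RnaBertForMaskedLM(config)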
30 changes: 27 additions & 3 deletions multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,36 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import (
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+)

-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaTokenizer",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+]

AutoConfig.register("rnabert", RnaBertConfig)
AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
+AutoModelWithLMHead.register(RnaBertConfig, RnaBertForMaskedLM)
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
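Once this module has been imported, the generic transformers Auto classes resolve the "rnabert" model type. A hedged sketch; "path/to/rnabert" stands in for a local checkpoint directory such as the one produced by the conversion script below:

from transformers import AutoModelForMaskedLM, AutoTokenizer

import multimolecule.models  # noqa: F401 -- importing runs the registrations above

# "path/to/rnabert" is a hypothetical local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("path/to/rnabert")
model = AutoModelForMaskedLM.from_pretrained("path/to/rnabert")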
9 changes: 7 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the RnaBert
-[mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+[mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
>>> # Initializing a model from the configuration
>>> model = RnaBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
-```"""
+```
+"""

model_type = "rnabert"

@@ -77,6 +78,8 @@ def __init__(
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
+        classifier_dropout=None,
+        proj_head_mode="nonlinear",
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -97,3 +100,5 @@ def __init__(
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+        self.proj_head_mode = proj_head_mode
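The two new fields surface in the constructor as shown above. A brief sketch of how they might be set; the values are illustrative, and the assumption that classifier_dropout feeds the new classification heads follows the usual transformers convention rather than anything shown in this diff:

from multimolecule.models import RnaBertConfig

# classifier_dropout=None would keep the model's default dropout for the heads;
# proj_head_mode selects the projection head variant ("nonlinear" is the default here).
config = RnaBertConfig(classifier_dropout=0.1, proj_head_mode="nonlinear")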
31 changes: 22 additions & 9 deletions multimolecule/models/rnabert/convert_checkpoint.py
@@ -6,7 +6,7 @@
import torch
from torch import nn

-from multimolecule.models import RnaBertConfig, RnaBertModel
+from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM
from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

CONFIG = {
@@ -19,7 +19,6 @@
"max_position_embeddings": 440,
"num_attention_heads": 12,
"num_hidden_layers": 6,
"vocab_size": 25,
"ss_vocab_size": 8,
"type_vocab_size": 2,
"pad_token_id": 0,
@@ -33,27 +32,41 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
    if output_path is None:
        output_path = "rnabert"
    config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG))
+    config.vocab_size = len(vocab_list)
    ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    bert_state_dict = ckpt
    state_dict = {}

-    model = RnaBertModel(config)
+    model = RnaBertForMaskedLM(config)

    for key, value in bert_state_dict.items():
-        if key.startswith("module.cls"):
-            continue
-        key = key[12:]
+        key = key[7:]  # drop the "module." prefix on every key
        key = key.replace("gamma", "weight")
        key = key.replace("beta", "bias")
-        state_dict[key] = value
+        if key.startswith("bert"):
+            state_dict["rna" + key] = value
+            continue
+        if key.startswith("cls"):
+            key = "lm_head." + key[4:]
+            state_dict[key] = value
+            continue
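The renaming loop above translates original RNABERT checkpoint keys into the new module layout. An illustrative re-implementation, assuming every key carries the "module." prefix that nn.DataParallel adds when saving:

def translate_key(key: str) -> str:
    key = key[7:]  # drop the "module." prefix
    key = key.replace("gamma", "weight").replace("beta", "bias")  # legacy LayerNorm names
    if key.startswith("bert"):
        return "rna" + key  # backbone weights live under the "rnabert." prefix
    if key.startswith("cls"):
        return "lm_head." + key[4:]  # the pretraining head becomes lm_head
    return key

# e.g. "module.bert.encoder.layer.0.output.LayerNorm.gamma"
#   -> "rnabert.encoder.layer.0.output.LayerNorm.weight"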

    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+    word_embed_weight = word_embed.weight.data
+    predictions_bias = torch.zeros(config.vocab_size)
+    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
    # nn.init.normal_(pos_embed.weight, std=0.02)
    for original_token, new_token in zip(original_vocab_list, vocab_list):
        original_index = original_vocab_list.index(original_token)
        new_index = vocab_list.index(new_token)
-        word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index]
-    state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data
+        word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index]
+        predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index]
+        predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index]
+    state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["lm_head.predictions.bias"] = predictions_bias
+    state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight

    model.load_state_dict(state_dict)
    model.save_pretrained(output_path, safe_serialization=True)
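For completeness, a hedged sketch of invoking the converter; the checkpoint filename is hypothetical, and the script's actual CLI wiring is collapsed out of this diff:

# Hypothetical call; adjust the path to the original RNABERT checkpoint.
convert_checkpoint("path/to/rnabert_original.pth", output_path="rnabert")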
(2 more changed files not shown.)
