add support for RnaFm
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 18, 2024
1 parent 517031f commit 575551f
Showing 6 changed files with 1,486 additions and 0 deletions.
12 changes: 12 additions & 0 deletions multimolecule/__init__.py
@@ -6,6 +6,12 @@
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
RnaFmConfig,
RnaFmForMaskedLM,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
RnaFmModel,
RnaMsmConfig,
RnaMsmForMaskedLM,
RnaMsmForPretraining,
@@ -31,6 +37,12 @@
"RnaBertForPretraining",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaFmConfig",
"RnaFmModel",
"RnaFmForMaskedLM",
"RnaFmForPretraining",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
"RnaMsmConfig",
"RnaMsmModel",
"RnaMsmForMaskedLM",
14 changes: 14 additions & 0 deletions multimolecule/models/__init__.py
@@ -7,6 +7,14 @@
RnaBertForTokenClassification,
RnaBertModel,
)
from .rnafm import (
RnaFmConfig,
RnaFmForMaskedLM,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
RnaFmModel,
)
from .rnamsm import (
RnaMsmConfig,
RnaMsmForMaskedLM,
@@ -32,6 +40,12 @@
"RnaBertForPretraining",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaFmConfig",
"RnaFmForMaskedLM",
"RnaFmForPretraining",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
"RnaFmModel",
"RnaMsmConfig",
"RnaMsmModel",
"RnaMsmForMaskedLM",
40 changes: 40 additions & 0 deletions multimolecule/models/rnafm/__init__.py
@@ -0,0 +1,40 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnafm import RnaFmConfig
from .modeling_rnafm import (
RnaFmForMaskedLM,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
RnaFmModel,
)

__all__ = [
"RnaFmConfig",
"RnaFmModel",
"RnaTokenizer",
"RnaFmForMaskedLM",
"RnaFmForPretraining",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
]

AutoConfig.register("rnafm", RnaFmConfig)
AutoModel.register(RnaFmConfig, RnaFmModel)
AutoModelForMaskedLM.register(RnaFmConfig, RnaFmForMaskedLM)
AutoModelForPreTraining.register(RnaFmConfig, RnaFmForPretraining)
AutoModelForSequenceClassification.register(RnaFmConfig, RnaFmForSequenceClassification)
AutoModelForTokenClassification.register(RnaFmConfig, RnaFmForTokenClassification)
AutoModelWithLMHead.register(RnaFmConfig, RnaFmForMaskedLM)
AutoTokenizer.register(RnaFmConfig, RnaTokenizer)
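
With these registrations in place, the RNA-FM classes resolve through the transformers Auto API by model type. A minimal sketch, assuming the package is importable; the hub id `multimolecule/rnafm` mirrors the default `repo_id` constructed in `convert_checkpoint.py` and presumes the converted weights have been published:

```python
# Minimal sketch: resolve RNA-FM through the Auto API after registration.
import multimolecule  # noqa: F401  # importing the package runs the registrations above
from transformers import AutoConfig, AutoModel

config = AutoConfig.for_model("rnafm")  # builds an RnaFmConfig with default hyperparameters
model = AutoModel.from_config(config)   # RnaFmModel with randomly initialized weights
# model = AutoModel.from_pretrained("multimolecule/rnafm")  # once converted weights are published
```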
125 changes: 125 additions & 0 deletions multimolecule/models/rnafm/configuration_rnafm.py
@@ -0,0 +1,125 @@
from transformers.utils import logging

from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig

logger = logging.get_logger(__name__)


class RnaFmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RnaFmModel`]. It is used to instantiate a RNA-FM
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the RNA-FM
[ml4bio/RNA-FM](https://github.com/ml4bio/RNA-FM) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 25):
Vocabulary size of the RNA-FM model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`RnaFmModel`].
hidden_size (`int`, *optional*, defaults to 640):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 20):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 5120):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 1026):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`, `"rotary"`.
For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
is_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
emb_layer_norm_before (`bool`, *optional*, defaults to `True`):
Whether to apply layer normalization after embeddings but before the main stem of the network.
token_dropout (`bool`, *optional*, defaults to `True`):
When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
Examples:
```python
>>> from multimolecule import RnaFmModel, RnaFmConfig
>>> # Initializing a RNA-FM multimolecule/rnafm style configuration
>>> configuration = RnaFmConfig()
>>> # Initializing a model (with random weights) from the multimolecule/rnafm style configuration
>>> model = RnaFmModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""

model_type = "rnafm"

def __init__(
self,
vocab_size=25,
hidden_size=640,
num_hidden_layers=12,
num_attention_heads=20,
intermediate_size=5120,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1026,
initializer_range=0.02,
layer_norm_eps=1e-12,
position_embedding_type="absolute",
use_cache=True,
emb_layer_norm_before=True,
token_dropout=True,
head=None,
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
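
The constructor seeds the task-head and masked-LM-head sub-configurations from the top-level arguments before handing the remaining kwargs to `PretrainedConfig`. A minimal sketch of that propagation, assuming `HeadConfig` and `MaskedLMHeadConfig` expose these fields as attributes:

```python
# Minimal sketch: top-level arguments propagate into the head sub-configs.
from multimolecule import RnaFmConfig

config = RnaFmConfig(num_labels=2, problem_type="single_label_classification")
assert config.head.hidden_size == config.hidden_size == 640   # inherited via setdefault
assert config.head.num_labels == 2                            # copied from kwargs
assert config.lm_head.hidden_size == 640                      # masked-LM head follows hidden_size
```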
146 changes: 146 additions & 0 deletions multimolecule/models/rnafm/convert_checkpoint.py
@@ -0,0 +1,146 @@
import os
from typing import Optional

import chanfig
import torch
from torch import nn

from multimolecule.models import RnaFmConfig as Config
from multimolecule.models import RnaFmForPretraining as Model
from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

try:
from huggingface_hub import HfApi
except ImportError:
HfApi = None


torch.manual_seed(1013)

original_vocab_list = [
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"A",
"C",
"G",
"U",
"R",
"Y",
"K",
"M",
"S",
"W",
"B",
"D",
"H",
"V",
"N",
"-",
"<null>",
"<null>",
"<null>",
"<null>",
"<mask>",
]
vocab_list = get_vocab_list()


def _convert_checkpoint(config, original_state_dict):
state_dict = {}
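# Rename parameter keys from the original RNA-FM (ESM-style) checkpoint to the module paths used by RnaFmModel.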
for key, value in original_state_dict.items():
key = "rnafm" + key[7:]
key = key.replace("LayerNorm", "layer_norm")
key = key.replace("gamma", "weight")
key = key.replace("beta", "bias")
key = key.replace("rnafm.encoder.emb_layer_norm_before", "rnafm.embeddings.layer_norm")
key = key.replace("rnafm.encoder.embed_tokens", "rnafm.embeddings.word_embeddings")
key = key.replace("rnafm.encoder.embed_positions", "rnafm.embeddings.position_embeddings")
key = key.replace("layers", "layer")
key = key.replace("self_attn", "attention.self")
key = key.replace("q_proj", "query")
key = key.replace("k_proj", "key")
key = key.replace("v_proj", "value")
key = key.replace("self.out_proj", "output.dense")
key = key.replace("self_layer_norm", "layer_norm")
key = key.replace("final_layer_norm", "layer_norm")
key = key.replace("fc1", "intermediate.dense")
key = key.replace("fc2", "output.dense")
key = key.replace("regression", "decoder")
key = key.replace("rnafm.encoder.lm_head", "pretrain_head.predictions")
key = key.replace("predictions.dense", "predictions.transform.dense")
key = key.replace("predictions.layer_norm", "predictions.transform.layer_norm")
key = key.replace("predictions.weight", "predictions.decoder.weight")
key = key.replace("rnafm.encoder.contact_head", "pretrain_head.contact")
state_dict[key] = value

state_vocab_size = state_dict["rnafm.embeddings.word_embeddings.weight"].size(0)
original_vocab_size = len(original_vocab_list)
if state_vocab_size != original_vocab_size:
raise ValueError(
f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
)
word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
word_embed_weight = word_embed.weight.data
predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
predictions_bias = torch.zeros(config.vocab_size)
# nn.init.normal_(pos_embed.weight, std=0.02)
for original_index, original_token in enumerate(original_vocab_list):
new_index = vocab_list.index(original_token)
word_embed_weight[new_index] = state_dict["rnafm.embeddings.word_embeddings.weight"][original_index]
predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index]
predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight
state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight
state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
predictions_bias
)
return state_dict
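
Besides renaming keys, the converter reorders embedding and prediction-head rows because the MultiMolecule vocabulary indexes tokens differently from the original checkpoint. A standalone sketch of that remapping, using hypothetical toy vocabularies:

```python
# Standalone sketch of the row remapping above, using toy vocabularies.
import torch

original_vocab = ["<cls>", "<pad>", "A", "C"]          # order used by the original checkpoint
new_vocab = ["<pad>", "<cls>", "C", "A", "<mask>"]     # hypothetical target order

old_weight = torch.arange(len(original_vocab) * 2, dtype=torch.float).view(-1, 2)
new_weight = torch.zeros(len(new_vocab), 2)            # rows for unseen tokens stay zero-initialized here
for old_index, token in enumerate(original_vocab):
    new_weight[new_vocab.index(token)] = old_weight[old_index]
```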


def convert_checkpoint(convert_config):
config = Config(num_labels=1)
config.architectures = ["RnaFmModel"]
config.vocab_size = len(vocab_list)

model = Model(config)

ckpt = torch.load(convert_config.checkpoint_path, map_location=torch.device("cpu"))
state_dict = _convert_checkpoint(config, ckpt)

model.load_state_dict(state_dict)
model.save_pretrained(convert_config.output_path, safe_serialization=True)
model.save_pretrained(convert_config.output_path, safe_serialization=False)
chanfig.NestedDict(get_special_tokens_map()).json(
os.path.join(convert_config.output_path, "special_tokens_map.json")
)
chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))

if convert_config.push_to_hub:
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
exist_ok=True,
)
api.upload_folder(
repo_id=convert_config.repo_id, folder_path=convert_config.output_path, token=convert_config.token
)


@chanfig.configclass
class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
token: Optional[str] = None


if __name__ == "__main__":
config = ConvertConfig()
config.parse() # type: ignore[attr-defined]
convert_checkpoint(config)
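
For completeness, a hedged sketch of driving the converter from another script rather than the CLI; the checkpoint filename is a placeholder, and the command-line flag spelling is an assumption about chanfig's argument parsing:

```python
# Hypothetical usage; the checkpoint filename is a placeholder.
# Command line (flag spelling assumed from chanfig's parser):
#     python convert_checkpoint.py --checkpoint_path RNA-FM_pretrained.pth
# Programmatic:
from multimolecule.models.rnafm.convert_checkpoint import ConvertConfig, convert_checkpoint

convert_config = ConvertConfig()
convert_config.checkpoint_path = "RNA-FM_pretrained.pth"  # placeholder path to the original weights
convert_checkpoint(convert_config)
```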