
Commit

add support for 3UTR-BERT
ZhiyuanChen committed Apr 18, 2024
1 parent c2b7f19 commit a6bfcd6
Showing 7 changed files with 1,429 additions and 1 deletion.
12 changes: 12 additions & 0 deletions multimolecule/__init__.py
@@ -24,6 +24,12 @@
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
SpliceBertModel,
UtrBertConfig,
UtrBertForMaskedLM,
UtrBertForPretraining,
UtrBertForSequenceClassification,
UtrBertForTokenClassification,
UtrBertModel,
UtrLmConfig,
UtrLmForMaskedLM,
UtrLmForPretraining,
@@ -61,6 +67,12 @@
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
"UtrBertConfig",
"UtrBertModel",
"UtrBertForMaskedLM",
"UtrBertForPretraining",
"UtrBertForSequenceClassification",
"UtrBertForTokenClassification",
"UtrLmConfig",
"UtrLmModel",
"UtrLmForMaskedLM",
14 changes: 14 additions & 0 deletions multimolecule/models/__init__.py
@@ -31,6 +31,14 @@
SpliceBertForTokenClassification,
SpliceBertModel,
)
from .utrbert import (
UtrBertConfig,
UtrBertForMaskedLM,
UtrBertForPretraining,
UtrBertForSequenceClassification,
UtrBertForTokenClassification,
UtrBertModel,
)
from .utrlm import (
UtrLmConfig,
UtrLmForMaskedLM,
@@ -66,6 +74,12 @@
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
"UtrBertConfig",
"UtrBertModel",
"UtrBertForMaskedLM",
"UtrBertForPretraining",
"UtrBertForSequenceClassification",
"UtrBertForTokenClassification",
"UtrLmConfig",
"UtrLmForMaskedLM",
"UtrLmForPretraining",
2 changes: 1 addition & 1 deletion multimolecule/models/splicebert/modeling_splicebert.py
@@ -324,7 +324,7 @@ def __init__(self, config: SpliceBertConfig):
"If you want to use `SpliceBertForPretraining` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
- self.splicebert = SpliceBertModel(config, add_pooling_layer=False)
+ self.splicebert = SpliceBertModel(config, add_pooling_layer=True)
self.lm_head = MaskedLMHead(config)

# Initialize weights and apply final processing
39 changes: 39 additions & 0 deletions multimolecule/models/utrbert/__init__.py
@@ -0,0 +1,39 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_utrbert import UtrBertConfig
from .modeling_utrbert import (
UtrBertForMaskedLM,
UtrBertForPretraining,
UtrBertForSequenceClassification,
UtrBertForTokenClassification,
UtrBertModel,
)

__all__ = [
"UtrBertConfig",
"UtrBertModel",
"UtrBertForMaskedLM",
"UtrBertForPretraining",
"UtrBertForSequenceClassification",
"UtrBertForTokenClassification",
]

AutoConfig.register("utrbert", UtrBertConfig)
AutoModel.register(UtrBertConfig, UtrBertModel)
AutoModelForMaskedLM.register(UtrBertConfig, UtrBertForMaskedLM)
AutoModelForPreTraining.register(UtrBertConfig, UtrBertForPretraining)
AutoModelForSequenceClassification.register(UtrBertConfig, UtrBertForSequenceClassification)
AutoModelForTokenClassification.register(UtrBertConfig, UtrBertForTokenClassification)
AutoModelWithLMHead.register(UtrBertConfig, UtrBertForMaskedLM)
AutoTokenizer.register(UtrBertConfig, RnaTokenizer)
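
These registrations make the "utrbert" model type resolvable through the generic Auto* factories once multimolecule has been imported. A minimal usage sketch, assuming a converted checkpoint is available under a hypothetical name such as multimolecule/utrbert-3mer:

import multimolecule  # noqa: F401  # importing the package runs the Auto* registrations above
from transformers import AutoModel, AutoTokenizer

checkpoint = "multimolecule/utrbert-3mer"  # hypothetical converted 3UTRBERT checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # resolves to RnaTokenizer
model = AutoModel.from_pretrained(checkpoint)  # resolves to UtrBertModel

inputs = tokenizer("AUGGCACUAGUA", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
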
113 changes: 113 additions & 0 deletions multimolecule/models/utrbert/configuration_utrbert.py
@@ -0,0 +1,113 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class UtrBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`UtrBertModel`]. It is used to instantiate a
3UTRBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the 3UTRBERT
[yangyn533/3UTRBERT](https://github.com/yangyn533/3UTRBERT) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*):
Vocabulary size of the 3UTRBERT model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`UtrBertModel`].
nmers (`int`, *optional*):
The k-mer size of the 3UTRBERT model, which determines the vocabulary size.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
hidden_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (`int`, *optional*, defaults to 2):
The vocabulary size of the `token_type_ids` passed when calling [`UtrBertModel`].
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
is_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertModel
>>> # Initializing a UtrBERT multimolecule/utrbert style configuration
>>> configuration = UtrBertConfig()
>>> # Initializing a model (with random weights) from the multimolecule/utrbert style configuration
>>> model = UtrBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""

model_type = "utrbert"

def __init__(
self,
vocab_size=None,
nmers=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)

self.vocab_size = vocab_size
self.nmers = nmers
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
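
The k-mer size interacts with the vocabulary: `vocab_size` and `nmers` default to `None`, so a from-scratch configuration will normally set both. A minimal sketch for a 3-mer model; the `4**3 + 6` below (64 RNA 3-mers plus an assumed six special tokens) is illustrative only, since the real count comes from the RNA tokenizer's vocabulary list:

from multimolecule import UtrBertConfig, UtrBertModel

nmers = 3
config = UtrBertConfig(vocab_size=4**nmers + 6, nmers=nmers)  # assumed: 64 3-mers plus six special tokens
model = UtrBertModel(config)  # randomly initialized weights
print(config.vocab_size, model.config.nmers)
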
124 changes: 124 additions & 0 deletions multimolecule/models/utrbert/convert_checkpoint.py
@@ -0,0 +1,124 @@
import os
from typing import Optional

import chanfig
import torch
from torch import nn

from multimolecule.models import UtrBertConfig as Config
from multimolecule.models import UtrBertForPretraining as Model
from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

try:
from huggingface_hub import HfApi
except ImportError:
HfApi = None

torch.manual_seed(1013)


def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
state_dict = {}
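# rename parameters from the original BERT-style checkpoint to the multimolecule naming scheme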
for key, value in original_state_dict.items():
key = key.replace("LayerNorm", "layer_norm")
key = key.replace("gamma", "weight")
key = key.replace("beta", "bias")
if key.startswith("bert"):
state_dict["splice" + key] = value
continue
if key.startswith("cls"):
key = "lm_head" + key[15:]
state_dict[key] = value
continue
state_dict[key] = value

state_vocab_size = state_dict["utrbert.embeddings.word_embeddings.weight"].size(0)
original_vocab_size = len(original_vocab_list)
if state_vocab_size != original_vocab_size:
raise ValueError(
f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
)
word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
word_embed_weight = word_embed.weight.data
predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
predictions_bias = torch.zeros(config.vocab_size)
# remap token embedding and LM-head rows from the original vocabulary order to the new vocabulary order
for original_index, original_token in enumerate(original_vocab_list):
new_index = vocab_list.index(original_token)
word_embed_weight[new_index] = state_dict["utrbert.embeddings.word_embeddings.weight"][original_index]
predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
state_dict["utrbert.embeddings.word_embeddings.weight"] = word_embed_weight
state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
return state_dict


def convert_checkpoint(convert_config):
config = chanfig.load(os.path.join(convert_config.checkpoint_path, "config.json"))
config.hidden_dropout = config.pop("hidden_dropout_prob", 0.1)
config.attention_dropout = config.pop("attention_probs_dropout_prob", 0.1)
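# the checkpoint directory name is expected to begin with the k-mer size (e.g. "3mer"), so read its first character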
config.nmers = int(convert_config.checkpoint_path.split("/")[-1][0])
vocab_list = get_vocab_list(config.nmers)
config = Config.from_dict(config)
del config._name_or_path
config.architectures = ["UtrBertModel"]
config.vocab_size = len(vocab_list)

model = Model(config)

ckpt = torch.load(
os.path.join(convert_config.checkpoint_path, "pytorch_model.bin"), map_location=torch.device("cpu")
)
original_vocab_list = []
for char in open(os.path.join(convert_config.checkpoint_path, "vocab.txt")).read().splitlines(): # noqa: SIM115
if char.startswith("["):
char = char.lower().replace("[", "<").replace("]", ">")
if char == "T":
char = "U"
if char == "<sep>":
char = "<eos>"
original_vocab_list.append(char)
state_dict = _convert_checkpoint(config, ckpt, vocab_list, original_vocab_list)

model.load_state_dict(state_dict)
model.save_pretrained(convert_config.output_path, safe_serialization=True)
model.save_pretrained(convert_config.output_path, safe_serialization=False)
chanfig.NestedDict(get_special_tokens_map()).json(
os.path.join(convert_config.output_path, "special_tokens_map.json")
)
chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))

if convert_config.push_to_hub:
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
exist_ok=True,
)
api.upload_folder(
repo_id=convert_config.repo_id, folder_path=convert_config.output_path, token=convert_config.token
)


@chanfig.configclass
class ConvertConfig:
checkpoint_path: str
output_path: Optional[str] = None
push_to_hub: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.output_path is None:
self.output_path = self.checkpoint_path.split("/")[-1].lower()
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
config.parse() # type: ignore[attr-defined]
convert_checkpoint(config)
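
The script is normally driven from the command line through `chanfig`, but the same flow can be sketched programmatically. The paths and repository name below are hypothetical; the checkpoint directory must contain the original config.json, pytorch_model.bin and vocab.txt, and its name must begin with the k-mer digit because `config.nmers` is read from the first character:

from multimolecule.models.utrbert.convert_checkpoint import ConvertConfig, convert_checkpoint

config = ConvertConfig()
config.checkpoint_path = "3mer"  # hypothetical local 3UTRBERT checkpoint directory
config.output_path = "utrbert-3mer"  # hypothetical output directory
config.repo_id = "multimolecule/utrbert-3mer"  # only used when push_to_hub is True
convert_checkpoint(config)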