add support for SpliceBert
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 18, 2024
1 parent ae6ace9 commit 517031f
Showing 6 changed files with 1,348 additions and 0 deletions.
12 changes: 12 additions & 0 deletions multimolecule/__init__.py
@@ -12,6 +12,12 @@
RnaMsmForSequenceClassification,
RnaMsmForTokenClassification,
RnaMsmModel,
SpliceBertConfig,
SpliceBertForMaskedLM,
SpliceBertForPretraining,
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
SpliceBertModel,
)
from .tokenizers import RnaTokenizer

@@ -31,4 +37,10 @@
"RnaMsmForPretraining",
"RnaMsmForSequenceClassification",
"RnaMsmForTokenClassification",
"SpliceBertConfig",
"SpliceBertModel",
"SpliceBertForMaskedLM",
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
]
14 changes: 14 additions & 0 deletions multimolecule/models/__init__.py
@@ -15,6 +15,14 @@
RnaMsmForTokenClassification,
RnaMsmModel,
)
from .splicebert import (
SpliceBertConfig,
SpliceBertForMaskedLM,
SpliceBertForPretraining,
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
SpliceBertModel,
)

__all__ = [
"RnaTokenizer",
@@ -30,4 +38,10 @@
"RnaMsmForPretraining",
"RnaMsmForSequenceClassification",
"RnaMsmForTokenClassification",
"SpliceBertConfig",
"SpliceBertModel",
"SpliceBertForMaskedLM",
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
]
39 changes: 39 additions & 0 deletions multimolecule/models/splicebert/__init__.py
@@ -0,0 +1,39 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_splicebert import SpliceBertConfig
from .modeling_splicebert import (
SpliceBertForMaskedLM,
SpliceBertForPretraining,
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
SpliceBertModel,
)

__all__ = [
"SpliceBertConfig",
"SpliceBertModel",
"SpliceBertForMaskedLM",
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
]

AutoConfig.register("splicebert", SpliceBertConfig)
AutoModel.register(SpliceBertConfig, SpliceBertModel)
AutoModelForMaskedLM.register(SpliceBertConfig, SpliceBertForMaskedLM)
AutoModelForPreTraining.register(SpliceBertConfig, SpliceBertForPretraining)
AutoModelForSequenceClassification.register(SpliceBertConfig, SpliceBertForSequenceClassification)
AutoModelForTokenClassification.register(SpliceBertConfig, SpliceBertForTokenClassification)
AutoModelWithLMHead.register(SpliceBertConfig, SpliceBertForMaskedLM)
AutoTokenizer.register(SpliceBertConfig, RnaTokenizer)
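
For reference, these `Auto*` registrations let a converted SpliceBert checkpoint be loaded through the standard `transformers` auto classes. A minimal sketch, assuming a converted checkpoint has been published under the illustrative repository id `multimolecule/splicebert` and that `SpliceBertModel` follows the usual BERT-style output with `last_hidden_state`:

```python
# Importing the package runs the Auto* registrations above.
import multimolecule  # noqa: F401
from transformers import AutoModel, AutoTokenizer

# "multimolecule/splicebert" is an illustrative repository id, not guaranteed by this commit.
tokenizer = AutoTokenizer.from_pretrained("multimolecule/splicebert")
model = AutoModel.from_pretrained("multimolecule/splicebert")

inputs = tokenizer("AGUCUGGUAUCC", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```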
107 changes: 107 additions & 0 deletions multimolecule/models/splicebert/configuration_splicebert.py
@@ -0,0 +1,107 @@
from transformers.utils import logging

from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig

logger = logging.get_logger(__name__)


class SpliceBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SpliceBertModel`]. It is used to instantiate a
SpliceBert model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SpliceBert
[biomed-AI/SpliceBERT](https://github.com/biomed-AI/SpliceBERT) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 25):
Vocabulary size of the SpliceBert model. Defines the number of different tokens that can be represented by
the `inputs_ids` passed when calling [`SpliceBertModel`].
hidden_size (`int`, *optional*, defaults to 512):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 1026):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
Examples:
```python
>>> from multimolecule import SpliceBertModel, SpliceBertConfig
>>> # Initializing a SpliceBERT multimolecule/splicebert style configuration
>>> configuration = SpliceBertConfig()
>>> # Initializing a model (with random weights) from the multimolecule/splicebert style configuration
>>> model = SpliceBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""

model_type = "splicebert"

def __init__(
self,
vocab_size=25,
hidden_size=512,
num_hidden_layers=6,
num_attention_heads=16,
intermediate_size=2048,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1026,
initializer_range=0.02,
layer_norm_eps=1e-12,
position_embedding_type="absolute",
use_cache=True,
head=None,
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
self.type_vocab_size = 2
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
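
The `head` and `lm_head` arguments are plain dicts that are promoted to `HeadConfig` and `MaskedLMHeadConfig`, inheriting the encoder `hidden_size` and any `num_labels` or `problem_type` passed as ordinary keyword arguments. A minimal sketch of that behaviour, assuming the head config objects expose the fields they are constructed with as attributes:

```python
from multimolecule import SpliceBertConfig

# With no explicit head dict, the task head inherits the encoder hidden size,
# and num_labels / problem_type given as plain kwargs are forwarded to it.
config = SpliceBertConfig(num_labels=3, problem_type="single_label_classification")
print(config.hidden_size)       # 512
print(config.head.hidden_size)  # 512, filled in via setdefault from hidden_size
print(config.head.num_labels)   # 3, forwarded from the kwargs

# An explicit head dict takes precedence over those defaults.
config = SpliceBertConfig(head={"hidden_size": 256})
print(config.head.hidden_size)  # 256
```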
124 changes: 124 additions & 0 deletions multimolecule/models/splicebert/convert_checkpoint.py
@@ -0,0 +1,124 @@
import os
from typing import Optional

import chanfig
import torch
from torch import nn

from multimolecule.models import SpliceBertConfig as Config
from multimolecule.models import SpliceBertForPretraining as Model
from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

try:
from huggingface_hub import HfApi
except ImportError:
HfApi = None


torch.manual_seed(1013)
vocab_list = get_vocab_list()


def _convert_checkpoint(config, original_state_dict, original_vocab_list):
state_dict = {}
for key, value in original_state_dict.items():
key = key.replace("LayerNorm", "layer_norm")
key = key.replace("gamma", "weight")
key = key.replace("beta", "bias")
if key.startswith("bert"):
state_dict["splice" + key] = value
continue
if key.startswith("cls"):
key = "lm_head" + key[15:]
state_dict[key] = value
continue

state_vocab_size = state_dict["splicebert.embeddings.word_embeddings.weight"].size(0)
original_vocab_size = len(original_vocab_list)
if state_vocab_size != original_vocab_size:
raise ValueError(
f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
)
word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
word_embed_weight = word_embed.weight.data
predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
predictions_bias = torch.zeros(config.vocab_size)
# remap embedding and LM-head rows from the original vocabulary order to the new vocabulary order
for original_index, original_token in enumerate(original_vocab_list):
new_index = vocab_list.index(original_token)
word_embed_weight[new_index] = state_dict["splicebert.embeddings.word_embeddings.weight"][original_index]
predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
state_dict["splicebert.embeddings.word_embeddings.weight"] = word_embed_weight
state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
del state_dict["splicebert.embeddings.position_ids"]
return state_dict


def convert_checkpoint(convert_config):
config = chanfig.load(os.path.join(convert_config.checkpoint_path, "config.json"))
config.hidden_dropout = config.pop("hidden_dropout_prob")
config.attention_dropout = config.pop("attention_probs_dropout_prob")
config = Config.from_dict(config)
del config._name_or_path
config.architectures = ["SpliceBertModel"]
config.vocab_size = len(vocab_list)

model = Model(config)

ckpt = torch.load(
os.path.join(convert_config.checkpoint_path, "pytorch_model.bin"), map_location=torch.device("cpu")
)
vocab = []
for char in open(os.path.join(convert_config.checkpoint_path, "vocab.txt")).read().splitlines(): # noqa: SIM115
if char.startswith("["):
char = char.lower().replace("[", "<").replace("]", ">")
if char == "T":
char = "U"
if char == "<sep>":
char = "<eos>"
vocab.append(char)
state_dict = _convert_checkpoint(config, ckpt, vocab)

model.load_state_dict(state_dict)
model.save_pretrained(convert_config.output_path, safe_serialization=True)
model.save_pretrained(convert_config.output_path, safe_serialization=False)
chanfig.NestedDict(get_special_tokens_map()).json(
os.path.join(convert_config.output_path, "special_tokens_map.json")
)
chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))

if convert_config.push_to_hub:
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
exist_ok=True,
)
api.upload_folder(
repo_id=convert_config.repo_id, folder_path=convert_config.output_path, token=convert_config.token
)


@chanfig.configclass
class ConvertConfig:
checkpoint_path: str
output_path: Optional[str] = None
push_to_hub: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.output_path is None:
self.output_path = self.checkpoint_path.split("/")[-1].lower()
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
config.parse() # type: ignore[attr-defined]
convert_checkpoint(config)
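
For reference, `chanfig.configclass` turns `ConvertConfig` into a small command-line config: in the `__main__` block, `parse()` fills it from flags such as `--checkpoint_path`, and `post()` derives `output_path` and `repo_id` when they are left unset. The same conversion can be driven programmatically; a minimal sketch with illustrative paths, assuming attribute-style assignment works as in the `__main__` block:

```python
from multimolecule.models.splicebert.convert_checkpoint import ConvertConfig, convert_checkpoint

config = ConvertConfig()
# "path/to/SpliceBERT" is an illustrative directory containing config.json,
# pytorch_model.bin and vocab.txt from the original SpliceBERT release.
config.checkpoint_path = "path/to/SpliceBERT"
config.post()  # output_path defaults to "splicebert", repo_id to "multimolecule/splicebert"
convert_checkpoint(config)
```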