
Commit

set default value of head in code
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 22, 2024
1 parent 05503a5 commit 5370c57
Showing 11 changed files with 33 additions and 80 deletions.
4 changes: 2 additions & 2 deletions multimolecule/models/configuration_utils.py
@@ -63,7 +63,7 @@ class HeadConfig:
            The activation function of the final prediction output.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
-       num_labels (`int`, *optional*, defaults to 2):
+       num_labels (`int`, *optional*, defaults to 1):
            Number of labels to use in the last layer added to the model, typically for a classification task.
        problem_type (`str`, *optional*):
            Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
@@ -77,7 +77,7 @@ class HeadConfig:
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
-   num_labels: int = 2
+   num_labels: int = 1
    problem_type: Optional[str] = None


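Since the hunk above only shows the changed fields, here is a minimal sketch of what the new default means for code that builds a HeadConfig directly (the import path is inferred from the file path above; treat it as an assumption):

```python
# Minimal sketch; assumes HeadConfig is importable from the module shown above.
from multimolecule.models.configuration_utils import HeadConfig

head = HeadConfig()
print(head.num_labels)   # 1 -- single-output heads are now the default

head = HeadConfig(num_labels=2)
print(head.num_labels)   # 2 -- classification heads must ask for it explicitly
```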
19 changes: 16 additions & 3 deletions multimolecule/models/modeling_utils.py
@@ -20,7 +20,13 @@ class ContactPredictionHead(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config.head
-       self.num_labels = config.head.num_labels
+       if self.config.hidden_size is None:
+           self.config.hidden_size = config.hidden_size
+       if self.config.num_labels is None:
+           self.config.num_labels = config.num_labels
+       if self.config.problem_type is None:
+           self.config.problem_type = config.problem_type
+       self.num_labels = self.config.num_labels
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
@@ -83,10 +89,11 @@ class MaskedLMHead(nn.Module):
    def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None):
        super().__init__()
        self.config = config.lm_head if hasattr(config, "lm_head") else config.head
+       if self.config.hidden_size is None:
+           self.config.hidden_size = config.hidden_size
        self.num_labels = config.vocab_size
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = PredictionHeadTransform.build(self.config)
-
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        if weight is not None:
            self.decoder.weight = weight
@@ -113,7 +120,13 @@ class ClassificationHead(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config.head
-       self.num_labels = config.head.num_labels
+       if self.config.hidden_size is None:
+           self.config.hidden_size = config.hidden_size
+       if self.config.num_labels is None:
+           self.config.num_labels = config.num_labels
+       if self.config.problem_type is None:
+           self.config.problem_type = config.problem_type
+       self.num_labels = self.config.num_labels
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = PredictionHeadTransform.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
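The heads now resolve their own defaults instead of relying on each model config to pre-fill them. A condensed, self-contained sketch of that fallback pattern (the classes below are stand-ins, not the actual PretrainedConfig or nn.Module code):

```python
# Illustrative sketch of the fallback the heads perform at construction time.
# ModelConfig and HeadConfig below are stand-ins, not the library classes.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HeadConfig:
    hidden_size: Optional[int] = None
    num_labels: int = 1                      # new default from this commit
    problem_type: Optional[str] = None


@dataclass
class ModelConfig:
    hidden_size: int = 120
    num_labels: int = 2
    problem_type: Optional[str] = None
    head: HeadConfig = field(default_factory=HeadConfig)


def resolve_head(config: ModelConfig) -> HeadConfig:
    head = config.head
    # Fields left as None on the head fall back to the model-level values.
    if head.hidden_size is None:
        head.hidden_size = config.hidden_size
    if head.num_labels is None:              # only fires if a head explicitly sets None
        head.num_labels = config.num_labels
    if head.problem_type is None:
        head.problem_type = config.problem_type
    return head


head = resolve_head(ModelConfig(problem_type="regression"))
print(head)  # HeadConfig(hidden_size=120, num_labels=1, problem_type='regression')
```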
14 changes: 2 additions & 12 deletions multimolecule/models/rnabert/configuration_rnabert.py
@@ -78,16 +78,6 @@ def __init__(
    ):
        if hidden_size is None:
            hidden_size = num_attention_heads * multiple if multiple is not None else 120
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -105,5 +95,5 @@ def __init__(
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
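With the pre-filling gone, the per-model configs simply hand whatever dict they were given (or nothing) to HeadConfig. A rough usage sketch, assuming the config class is exported as RnaBertConfig (the name and import path are not shown in this diff):

```python
# Hypothetical usage; RnaBertConfig name and import path are assumptions.
from multimolecule.models import RnaBertConfig

config = RnaBertConfig()                 # head and lm_head left as None
print(config.head.num_labels)            # 1, from the HeadConfig default
print(config.head.hidden_size)           # None here; filled from config.hidden_size
                                         # when a prediction head is instantiated

config = RnaBertConfig(head={"num_labels": 2})
print(config.head.num_labels)            # 2, explicit values still win
```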
14 changes: 2 additions & 12 deletions multimolecule/models/rnafm/configuration_rnafm.py
@@ -91,16 +91,6 @@ def __init__(
        lm_head=None,
        **kwargs,
    ):
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -118,5 +108,5 @@ def __init__(
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
2 changes: 1 addition & 1 deletion multimolecule/models/rnafm/convert_checkpoint.py
@@ -98,7 +98,7 @@ def convert_checkpoint(convert_config):
        "<null>",
        "<mask>",
    ]
-   config = Config(num_labels=1)
+   config = Config()
    config.architectures = ["RnaFmModel"]
    config.vocab_size = len(vocab_list)

14 changes: 2 additions & 12 deletions multimolecule/models/rnamsm/configuration_rnamsm.py
@@ -78,16 +78,6 @@ def __init__(
        lm_head=None,
        **kwargs,
    ):
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -107,5 +97,5 @@ def __init__(
        self.attention_type = attention_type
        self.embed_positions_msa = embed_positions_msa
        self.attention_bias = attention_bias
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
2 changes: 1 addition & 1 deletion multimolecule/models/rnamsm/convert_checkpoint.py
@@ -60,7 +60,7 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_
def convert_checkpoint(convert_config):
    vocab_list = get_vocab_list()
    original_vocab_list = ["<cls>", "<pad>", "<eos>", "<unk>", "A", "G", "C", "U", "X", "N", "-", "<mask>"]
-   config = Config(num_labels=1)
+   config = Config()
    config.architectures = ["RnaMsmModel"]
    config.vocab_size = len(vocab_list)

14 changes: 2 additions & 12 deletions multimolecule/models/splicebert/configuration_splicebert.py
@@ -74,16 +74,6 @@ def __init__(
        lm_head=None,
        **kwargs,
    ):
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -100,5 +90,5 @@ def __init__(
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/utrbert/configuration_utrbert.py
@@ -93,16 +93,6 @@ def __init__(
        lm_head=None,
        **kwargs,
    ):
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -120,5 +110,5 @@ def __init__(
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/utrlm/configuration_utrlm.py
@@ -102,16 +102,6 @@ def __init__(
        supervised_head=None,
        **kwargs,
    ):
-       if head is None:
-           head = {}
-       head.setdefault("hidden_size", hidden_size)
-       if "problem_type" in kwargs:
-           head.setdefault("problem_type", kwargs["problem_type"])
-       if "num_labels" in kwargs:
-           head.setdefault("num_labels", kwargs["num_labels"])
-       if lm_head is None:
-           lm_head = {}
-       lm_head.setdefault("hidden_size", hidden_size)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
@@ -129,7 +119,7 @@ def __init__(
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
-       self.head = HeadConfig(**head)
-       self.lm_head = MaskedLMHeadConfig(**lm_head)
+       self.head = HeadConfig(**head if head is not None else {})
+       self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
        self.structure_head = HeadConfig(**structure_head) if structure_head is not None else None
        self.supervised_head = HeadConfig(**supervised_head) if supervised_head is not None else None
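UTR-LM keeps its optional structure_head and supervised_head, which are only built when a dict is passed, as the checkpoint converter below does. A brief sketch (the UtrLmConfig class name and import path are assumptions; field names follow the diff above):

```python
# Hypothetical usage; UtrLmConfig name/path assumed, field names follow the diff above.
from multimolecule.models import UtrLmConfig

config = UtrLmConfig(structure_head={"num_labels": 3})
print(config.structure_head.num_labels)  # 3, built as a HeadConfig from the dict
print(config.supervised_head)            # None, optional heads stay unset unless requested
print(config.head.num_labels)            # 1, the main head uses the new default
```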
2 changes: 1 addition & 1 deletion multimolecule/models/utrlm/convert_checkpoint.py
@@ -75,7 +75,7 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_


def convert_checkpoint(convert_config):
-   config = chanfig.FlatDict(num_labels=1)
+   config = chanfig.FlatDict()
    config.supervised_head = {"num_labels": 1}
    if "4.1" in convert_config.checkpoint_path:
        config.structure_head = {"num_labels": 3}
