Skip to content

Commit

Permalink
set default value of head in code
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Apr 22, 2024
1 parent f887e64 commit dbb628e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 75 deletions.
19 changes: 16 additions & 3 deletions multimolecule/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ class ContactPredictionHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config.head
self.num_labels = config.head.num_labels
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
if self.config.num_labels is None:
self.config.num_labels = config.num_labels
if self.config.problem_type is None:
self.config.problem_type = config.problem_type
self.num_labels = self.config.num_labels
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.pad_token_id = config.pad_token_id
Expand Down Expand Up @@ -83,10 +89,11 @@ class MaskedLMHead(nn.Module):
def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None):
super().__init__()
self.config = config.lm_head if hasattr(config, "lm_head") else config.head
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
self.num_labels = config.vocab_size
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)

self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
if weight is not None:
self.decoder.weight = weight
Expand All @@ -113,7 +120,13 @@ class ClassificationHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config.head
self.num_labels = config.head.num_labels
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
if self.config.num_labels is None:
self.config.num_labels = config.num_labels
if self.config.problem_type is None:
self.config.problem_type = config.problem_type
self.num_labels = self.config.num_labels
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)
self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
Expand Down
14 changes: 2 additions & 12 deletions multimolecule/models/rnabert/configuration_rnabert.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ def __init__(
):
if hidden_size is None:
hidden_size = num_attention_heads * multiple if multiple is not None else 120
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -105,5 +95,5 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/rnafm/configuration_rnafm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,6 @@ def __init__(
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -118,5 +108,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/rnamsm/configuration_rnamsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ def __init__(
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -107,5 +97,5 @@ def __init__(
self.attention_type = attention_type
self.embed_positions_msa = embed_positions_msa
self.attention_bias = attention_bias
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/splicebert/configuration_splicebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,6 @@ def __init__(
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -100,5 +90,5 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/utrbert/configuration_utrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,6 @@ def __init__(
lm_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -120,5 +110,5 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
14 changes: 2 additions & 12 deletions multimolecule/models/utrlm/configuration_utrlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,6 @@ def __init__(
supervised_head=None,
**kwargs,
):
if head is None:
head = {}
head.setdefault("hidden_size", hidden_size)
if "problem_type" in kwargs:
head.setdefault("problem_type", kwargs["problem_type"])
if "num_labels" in kwargs:
head.setdefault("num_labels", kwargs["num_labels"])
if lm_head is None:
lm_head = {}
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(**kwargs)

self.vocab_size = vocab_size
Expand All @@ -129,7 +119,7 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.structure_head = HeadConfig(**structure_head) if structure_head is not None else None
self.supervised_head = HeadConfig(**supervised_head) if supervised_head is not None else None

0 comments on commit dbb628e

Please sign in to comment.