From dbb628e79dd6b89abf0351cebd184f01bcd174b4 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Mon, 22 Apr 2024 21:09:07 +0800 Subject: [PATCH] set default value of head in code Signed-off-by: Zhiyuan Chen --- multimolecule/models/modeling_utils.py | 19 ++++++++++++++++--- .../models/rnabert/configuration_rnabert.py | 14 ++------------ .../models/rnafm/configuration_rnafm.py | 14 ++------------ .../models/rnamsm/configuration_rnamsm.py | 14 ++------------ .../splicebert/configuration_splicebert.py | 14 ++------------ .../models/utrbert/configuration_utrbert.py | 14 ++------------ .../models/utrlm/configuration_utrlm.py | 14 ++------------ 7 files changed, 28 insertions(+), 75 deletions(-) diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py index ca504be5..bc8c8e4e 100644 --- a/multimolecule/models/modeling_utils.py +++ b/multimolecule/models/modeling_utils.py @@ -20,7 +20,13 @@ class ContactPredictionHead(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() self.config = config.head - self.num_labels = config.head.num_labels + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size + if self.config.num_labels is None: + self.config.num_labels = config.num_labels + if self.config.problem_type is None: + self.config.problem_type = config.problem_type + self.num_labels = self.config.num_labels self.bos_token_id = config.bos_token_id self.eos_token_id = config.eos_token_id self.pad_token_id = config.pad_token_id @@ -83,10 +89,11 @@ class MaskedLMHead(nn.Module): def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None): super().__init__() self.config = config.lm_head if hasattr(config, "lm_head") else config.head + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size self.num_labels = config.vocab_size self.dropout = nn.Dropout(self.config.dropout) self.transform = PredictionHeadTransform.build(self.config) - self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False) if weight is not None: self.decoder.weight = weight @@ -113,7 +120,13 @@ class ClassificationHead(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() self.config = config.head - self.num_labels = config.head.num_labels + if self.config.hidden_size is None: + self.config.hidden_size = config.hidden_size + if self.config.num_labels is None: + self.config.num_labels = config.num_labels + if self.config.problem_type is None: + self.config.problem_type = config.problem_type + self.num_labels = self.config.num_labels self.dropout = nn.Dropout(self.config.dropout) self.transform = PredictionHeadTransform.build(self.config) self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias) diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index 23ceef67..4d22e72e 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -78,16 +78,6 @@ def __init__( ): if hidden_size is None: hidden_size = num_attention_heads * multiple if multiple is not None else 120 - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -105,5 +95,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py index 425e5c82..d372a7ea 100644 --- a/multimolecule/models/rnafm/configuration_rnafm.py +++ b/multimolecule/models/rnafm/configuration_rnafm.py @@ -91,16 +91,6 @@ def __init__( lm_head=None, **kwargs, ): - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -118,5 +108,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py index 81e6566d..2232e8ac 100644 --- a/multimolecule/models/rnamsm/configuration_rnamsm.py +++ b/multimolecule/models/rnamsm/configuration_rnamsm.py @@ -78,16 +78,6 @@ def __init__( lm_head=None, **kwargs, ): - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -107,5 +97,5 @@ def __init__( self.attention_type = attention_type self.embed_positions_msa = embed_positions_msa self.attention_bias = attention_bias - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py index dd7cb838..6b70357d 100644 --- a/multimolecule/models/splicebert/configuration_splicebert.py +++ b/multimolecule/models/splicebert/configuration_splicebert.py @@ -74,16 +74,6 @@ def __init__( lm_head=None, **kwargs, ): - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -100,5 +90,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py index f0b32c7e..854e193c 100644 --- a/multimolecule/models/utrbert/configuration_utrbert.py +++ b/multimolecule/models/utrbert/configuration_utrbert.py @@ -93,16 +93,6 @@ def __init__( lm_head=None, **kwargs, ): - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -120,5 +110,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py index ba80481d..b965b7c1 100644 --- a/multimolecule/models/utrlm/configuration_utrlm.py +++ b/multimolecule/models/utrlm/configuration_utrlm.py @@ -102,16 +102,6 @@ def __init__( supervised_head=None, **kwargs, ): - if head is None: - head = {} - head.setdefault("hidden_size", hidden_size) - if "problem_type" in kwargs: - head.setdefault("problem_type", kwargs["problem_type"]) - if "num_labels" in kwargs: - head.setdefault("num_labels", kwargs["num_labels"]) - if lm_head is None: - lm_head = {} - lm_head.setdefault("hidden_size", hidden_size) super().__init__(**kwargs) self.vocab_size = vocab_size @@ -129,7 +119,7 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head) - self.lm_head = MaskedLMHeadConfig(**lm_head) + self.head = HeadConfig(**head if head is not None else {}) + self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) self.structure_head = HeadConfig(**structure_head) if structure_head is not None else None self.supervised_head = HeadConfig(**supervised_head) if supervised_head is not None else None