Commit

improve type hints in modeling_utils
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 3, 2024
1 parent d0017ed commit c12d569
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions multimolecule/models/modeling_utils.py
@@ -2,16 +2,17 @@
 
 import torch
 from chanfig import ConfigRegistry
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
 
 
 class MaskedLMHead(nn.Module):
     """Head for masked language modeling."""
 
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
             config.proj_head_mode = "none"
@@ -22,19 +23,19 @@ def __init__(self, config):
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        token_type_ids: Optional[Tensor] = None,
+        position_ids: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        encoder_hidden_states: Optional[Tensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        labels: Optional[Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+    ) -> Union[Tuple[Tensor], MaskedLMOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.bert(
@@ -76,7 +77,7 @@ class SequenceClassificationHead(nn.Module):
 
     num_labels: int
 
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
             config.proj_head_mode = "none"
@@ -91,7 +92,7 @@ def __init__(self, config):
         self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
     def forward(
-        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
+        self, outputs, labels: Optional[Tensor] = None, return_dict: Optional[bool] = None
     ) -> Union[Tuple, SequenceClassifierOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
@@ -135,7 +136,7 @@ class TokenClassificationHead(nn.Module):
 
     num_labels: int
 
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
             config.proj_head_mode = "none"
         super().__init__()
@@ -150,7 +151,7 @@ def __init__(self, config):
         self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
     def forward(
-        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
+        self, outputs, labels: Optional[Tensor] = None, return_dict: Optional[bool] = None
     ) -> Union[Tuple, TokenClassifierOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         token_output = outputs.pooled_output if return_dict else outputs[1]
@@ -194,7 +195,7 @@ def forward(
 
 @PredictionHeadTransform.register("nonlinear")
 class NonLinearTransform(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         if isinstance(config.hidden_act, str):
@@ -203,7 +204,7 @@ def __init__(self, config):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: Tensor) -> Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -212,18 +213,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 @PredictionHeadTransform.register("linear")
 class LinearTransform(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: Tensor) -> Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
         return hidden_states
 
 
 @PredictionHeadTransform.register("none")
 class IdentityTransform(nn.Identity):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()

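For reference, a minimal sketch of the annotation style this commit moves to: importing Tensor directly from torch and typing config as PretrainedConfig. The ExampleHead class below is hypothetical and not part of the repository; it only mirrors the pattern visible in the diff above.

from typing import Optional, Tuple, Union

from torch import Tensor, nn
from torch.nn import functional as F
from transformers.configuration_utils import PretrainedConfig


class ExampleHead(nn.Module):
    """Hypothetical head mirroring the annotation style adopted in this commit."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        # `hidden_size` is assumed to exist on the concrete config, as it does in the diff above.
        self.decoder = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, labels: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        output = self.decoder(hidden_states)
        if labels is not None:
            # Return the prediction together with a loss when labels are provided.
            return output, F.mse_loss(output, labels)
        return output

Because concrete configurations such as transformers.BertConfig subclass PretrainedConfig and define hidden_size, something like ExampleHead(BertConfig(hidden_size=64)) satisfies the annotation while keeping the constructor decoupled from any specific model family.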