Try a bidirectional model that processes right to left. #116

Open
matsen opened this issue Feb 13, 2025 · 0 comments
matsen commented Feb 13, 2025

The idea is that, with right-to-left processing, we would profit from J alignment as well.

Here's Claude's suggestion:

import math

import torch
import torch.nn as nn
from torch import Tensor

# PositionalEncoding, MAX_AMBIG_AA_IDX, and AbstractBinarySelectionModel are
# assumed to be importable from the existing codebase.


class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel):
    def __init__(
        self,
        nhead: int,
        d_model_per_head: int,
        dim_feedforward: int,
        layer_count: int,
        dropout_prob: float = 0.5,
        output_dim: int = 1,
    ):
        super().__init__()
        self.d_model_per_head = d_model_per_head
        self.d_model = d_model_per_head * nhead
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        
        # Forward direction components
        self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.forward_amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
        self.forward_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout_prob,  # pass dropout through to the encoder layer too
            batch_first=True,
        )
        self.forward_encoder = nn.TransformerEncoder(self.forward_encoder_layer, layer_count)
        
        # Reverse direction components
        self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.reverse_amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
        self.reverse_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout_prob,
            batch_first=True,
        )
        self.reverse_encoder = nn.TransformerEncoder(self.reverse_encoder_layer, layer_count)
        
        # Output layers
        self.combine_features = nn.Linear(2 * self.d_model, self.d_model)
        self.output = nn.Linear(self.d_model, output_dim)
        
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.combine_features.bias.data.zero_()
        self.combine_features.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.zero_()
        self.output.weight.data.uniform_(-initrange, initrange)

    def represent_sequence(
        self,
        indices: Tensor,
        mask: Tensor,
        embedding: nn.Embedding,
        pos_encoder: PositionalEncoding,
        encoder: nn.TransformerEncoder,
    ) -> Tensor:
        """Process sequence through one direction of the model."""
        embedded = embedding(indices) * math.sqrt(self.d_model)
        embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)
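        # PyTorch expects src_key_padding_mask to be True at positions to
        # ignore, whereas here mask is True at valid positions, hence ~mask.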
        return encoder(embedded, src_key_padding_mask=~mask)

    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        batch_size, seq_len = amino_acid_indices.shape
        seq_lengths = mask.sum(dim=1)
        
        # Forward direction - normal processing
        forward_repr = self.represent_sequence(
            amino_acid_indices, mask,
            self.forward_amino_acid_embedding,
            self.forward_pos_encoder,
            self.forward_encoder
        )
        
        # Reverse direction - flip sequences and masks
        reversed_indices = torch.zeros_like(amino_acid_indices)
        reversed_mask = torch.zeros_like(mask)
        
        for i in range(batch_size):
            length = seq_lengths[i]
            # Reverse and left-pad the sequence
            reversed_indices[i, -length:] = amino_acid_indices[i, :length].flip(0)
            reversed_mask[i, -length:] = mask[i, :length].flip(0)
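        # An aside (untested sketch): since the mask marks a contiguous valid
        # prefix, this loop is equivalent to a full flip along the sequence
        # dimension, avoiding the Python-level iteration:
        #   reversed_indices = amino_acid_indices.flip(1) * mask.flip(1)
        #   reversed_mask = mask.flip(1)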
            
        reverse_repr = self.represent_sequence(
            reversed_indices, reversed_mask,
            self.reverse_amino_acid_embedding,
            self.reverse_pos_encoder,
            self.reverse_encoder
        )
        
        # Un-reverse the representations to align with forward direction
        aligned_reverse_repr = torch.zeros_like(reverse_repr)
        for i in range(batch_size):
            length = seq_lengths[i]
            aligned_reverse_repr[i, :length] = reverse_repr[i, -length:].flip(0)
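        # Likewise, this un-reversal loop could be written as (untested):
        #   aligned_reverse_repr = reverse_repr.flip(1) * mask.unsqueeze(-1)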
            
        # Combine features
        combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1)
        combined = self.combine_features(combined)
        
        # Output layer
        return self.output(combined).squeeze(-1)

    @property
    def hyperparameters(self):
        return {
            "nhead": self.nhead,
            "d_model_per_head": self.d_model_per_head,
            "dim_feedforward": self.dim_feedforward,
            "layer_count": self.forward_encoder.num_layers,
            "dropout_prob": self.forward_pos_encoder.dropout.p,
            "output_dim": self.output.out_features,
        }
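
For completeness, a minimal smoke test of the class above might look like the following (an untested sketch; it assumes the class and its dependencies, including MAX_AMBIG_AA_IDX, are importable from the codebase, and the hyperparameter values are illustrative only):

import torch

model = BidirectionalTransformerBinarySelectionModel(
    nhead=4,
    d_model_per_head=16,
    dim_feedforward=256,
    layer_count=2,
)
model.eval()

batch_size, seq_len = 3, 20
# Random token indices in the embedding's valid range.
amino_acid_indices = torch.randint(0, MAX_AMBIG_AA_IDX + 1, (batch_size, seq_len))
# True marks valid positions; give each sequence a different prefix length.
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[0, 15:] = False
mask[1, 18:] = False
amino_acid_indices[~mask] = 0

with torch.no_grad():
    out = model(amino_acid_indices, mask)
print(out.shape)  # torch.Size([3, 20]) since output_dim == 1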