The idea is that we would profit from J alignment as well.
Here's Claude's suggestion:
import math

import torch
from torch import Tensor, nn

# PositionalEncoding, MAX_AMBIG_AA_IDX, and AbstractBinarySelectionModel are
# assumed to be available from the surrounding module, as for the existing models.


class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel):
    """Encode each sequence both left-to-right and right-to-left, then combine
    the two per-position representations before the output layer."""

    def __init__(
        self,
        nhead: int,
        d_model_per_head: int,
        dim_feedforward: int,
        layer_count: int,
        dropout_prob: float = 0.5,
        output_dim: int = 1,
    ):
        super().__init__()
        self.d_model_per_head = d_model_per_head
        self.d_model = d_model_per_head * nhead
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward

        # Forward direction components
        self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.forward_amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
        self.forward_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.forward_encoder = nn.TransformerEncoder(self.forward_encoder_layer, layer_count)

        # Reverse direction components
        self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.reverse_amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
        self.reverse_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.reverse_encoder = nn.TransformerEncoder(self.reverse_encoder_layer, layer_count)

        # Output layers
        self.combine_features = nn.Linear(2 * self.d_model, self.d_model)
        self.output = nn.Linear(self.d_model, output_dim)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.combine_features.bias.data.zero_()
        self.combine_features.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.zero_()
        self.output.weight.data.uniform_(-initrange, initrange)

    def represent_sequence(
        self,
        indices: Tensor,
        mask: Tensor,
        embedding: nn.Embedding,
        pos_encoder: PositionalEncoding,
        encoder: nn.TransformerEncoder,
    ) -> Tensor:
        """Process a sequence through one direction of the model."""
        embedded = embedding(indices) * math.sqrt(self.d_model)
        embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)
        return encoder(embedded, src_key_padding_mask=~mask)

    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        batch_size, seq_len = amino_acid_indices.shape
        seq_lengths = mask.sum(dim=1)

        # Forward direction - normal processing
        forward_repr = self.represent_sequence(
            amino_acid_indices,
            mask,
            self.forward_amino_acid_embedding,
            self.forward_pos_encoder,
            self.forward_encoder,
        )

        # Reverse direction - flip sequences and masks
        reversed_indices = torch.zeros_like(amino_acid_indices)
        reversed_mask = torch.zeros_like(mask)
        for i in range(batch_size):
            length = seq_lengths[i]
            # Reverse and left-pad the sequence
            reversed_indices[i, -length:] = amino_acid_indices[i, :length].flip(0)
            reversed_mask[i, -length:] = mask[i, :length].flip(0)

        reverse_repr = self.represent_sequence(
            reversed_indices,
            reversed_mask,
            self.reverse_amino_acid_embedding,
            self.reverse_pos_encoder,
            self.reverse_encoder,
        )

        # Un-reverse the representations to align with forward direction
        aligned_reverse_repr = torch.zeros_like(reverse_repr)
        for i in range(batch_size):
            length = seq_lengths[i]
            aligned_reverse_repr[i, :length] = reverse_repr[i, -length:].flip(0)

        # Combine features
        combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1)
        combined = self.combine_features(combined)

        # Output layer
        return self.output(combined).squeeze(-1)

    @property
    def hyperparameters(self):
        return {
            "nhead": self.nhead,
            "d_model_per_head": self.d_model_per_head,
            "dim_feedforward": self.dim_feedforward,
            "layer_count": self.forward_encoder.num_layers,
            "dropout_prob": self.forward_pos_encoder.dropout.p,
            "output_dim": self.output.out_features,
        }