Skip to content

Commit

Permalink
Create auto-resizing version of the Torrin network
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jan 29, 2025
1 parent f71b09c commit 1eee9bb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from __future__ import annotations

from typing import Self

import torch
from torch.nn import Module, Transformer, Conv1d, Parameter, Linear, Flatten, Sigmoid


class Torrin(Module):
def __init__(self):
@classmethod
def new(cls, input_length: int = 3500) -> Self:
return cls(input_length=input_length)

def __init__(self, input_length: int):
super().__init__()
embedding_size = 16
self.input_length = input_length
self.embedding_layer = Conv1d(in_channels=1, out_channels=embedding_size, kernel_size=35, stride=35)
self.transformer = Transformer(d_model=embedding_size, dim_feedforward=16, batch_first=True, num_decoder_layers=1)
self.transformer = Transformer(d_model=embedding_size, dim_feedforward=16, batch_first=True,
num_decoder_layers=1)
self.class_embedding = Parameter(torch.randn([1, 1, embedding_size]))
self.flatten = Flatten()
self.classification_layer = Linear(in_features=16, out_features=1)
self.sigmoid = Sigmoid()

def forward(self, x):
x = x.reshape([-1, 1, 3500])
x = x.reshape([-1, 1, self.input_length])
x = self.embedding_layer(x)
x = torch.permute(x, (0, 2, 1))
expanded_class_embedding = self.class_embedding.expand(x.size(0), -1, -1)
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/test_torran_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

from qusi.internal.torrin_model import Torrin


def test_lengths_give_correct_output_size():
torrin50 = Torrin.new(input_length=50)

output50 = torrin50(torch.arange(50, dtype=torch.float32).reshape([1, 50]))

assert output50.shape == torch.Size([1])

torrin1000 = Torrin.new(input_length=1000)

output1000 = torrin1000(torch.arange(1000, dtype=torch.float32).reshape([1, 1000]))

assert output1000.shape == torch.Size([1])

torrin3673 = Torrin.new(input_length=3673)

output3673 = torrin3673(torch.arange(3673, dtype=torch.float32).reshape([1, 3673]))

assert output3673.shape == torch.Size([1])

torrin100000 = Torrin.new(input_length=100000)

output100000 = torrin100000(
torch.arange(100000, dtype=torch.float32).reshape([1, 100000])
)

assert output100000.shape == torch.Size([1])

0 comments on commit 1eee9bb

Please sign in to comment.