From 1eee9bb7e890f3162e68aa6d6cb6df2f2a34a546 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Fri, 17 Jan 2025 17:12:34 -0500 Subject: [PATCH] Create auto-resizing version of the Torrin network --- .../{torran_model.py => torrin_model.py} | 14 +++++++-- tests/unit_tests/test_torran_model.py | 31 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) rename src/qusi/internal/{torran_model.py => torrin_model.py} (76%) create mode 100644 tests/unit_tests/test_torran_model.py diff --git a/src/qusi/internal/torran_model.py b/src/qusi/internal/torrin_model.py similarity index 76% rename from src/qusi/internal/torran_model.py rename to src/qusi/internal/torrin_model.py index d3830f0..c1dd804 100644 --- a/src/qusi/internal/torran_model.py +++ b/src/qusi/internal/torrin_model.py @@ -1,22 +1,30 @@ from __future__ import annotations +from typing import Self + import torch from torch.nn import Module, Transformer, Conv1d, Parameter, Linear, Flatten, Sigmoid class Torrin(Module): - def __init__(self): + @classmethod + def new(cls, input_length: int = 3500) -> Self: + return cls(input_length=input_length) + + def __init__(self, input_length: int): super().__init__() embedding_size = 16 + self.input_length = input_length self.embedding_layer = Conv1d(in_channels=1, out_channels=embedding_size, kernel_size=35, stride=35) - self.transformer = Transformer(d_model=embedding_size, dim_feedforward=16, batch_first=True, num_decoder_layers=1) + self.transformer = Transformer(d_model=embedding_size, dim_feedforward=16, batch_first=True, + num_decoder_layers=1) self.class_embedding = Parameter(torch.randn([1, 1, embedding_size])) self.flatten = Flatten() self.classification_layer = Linear(in_features=16, out_features=1) self.sigmoid = Sigmoid() def forward(self, x): - x = x.reshape([-1, 1, 3500]) + x = x.reshape([-1, 1, self.input_length]) x = self.embedding_layer(x) x = torch.permute(x, (0, 2, 1)) expanded_class_embedding = self.class_embedding.expand(x.size(0), -1, -1) diff --git a/tests/unit_tests/test_torran_model.py b/tests/unit_tests/test_torran_model.py new file mode 100644 index 0000000..2d4a67a --- /dev/null +++ b/tests/unit_tests/test_torran_model.py @@ -0,0 +1,31 @@ +import torch + +from qusi.internal.torrin_model import Torrin + + +def test_lengths_give_correct_output_size(): + torrin50 = Torrin.new(input_length=50) + + output50 = torrin50(torch.arange(50, dtype=torch.float32).reshape([1, 50])) + + assert output50.shape == torch.Size([1]) + + torrin1000 = Torrin.new(input_length=1000) + + output1000 = torrin1000(torch.arange(1000, dtype=torch.float32).reshape([1, 1000])) + + assert output1000.shape == torch.Size([1]) + + torrin3673 = Torrin.new(input_length=3673) + + output3673 = torrin3673(torch.arange(3673, dtype=torch.float32).reshape([1, 3673])) + + assert output3673.shape == torch.Size([1]) + + torrin100000 = Torrin.new(input_length=100000) + + output100000 = torrin100000( + torch.arange(100000, dtype=torch.float32).reshape([1, 100000]) + ) + + assert output100000.shape == torch.Size([1]) \ No newline at end of file