From e12ea4b1d60a628b0572720e882525bf015f5daa Mon Sep 17 00:00:00 2001 From: brianreicher Date: Tue, 7 Nov 2023 10:32:01 -0500 Subject: [PATCH] ACLSD model with 2 UNets --- src/raygun/torch/models/ACLSDModel.py | 37 ++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/raygun/torch/models/ACLSDModel.py b/src/raygun/torch/models/ACLSDModel.py index 90b251b..74e4d9d 100644 --- a/src/raygun/torch/models/ACLSDModel.py +++ b/src/raygun/torch/models/ACLSDModel.py @@ -1,3 +1,4 @@ +# from funlib.learn.torch.models import UNet, ConvPass from raygun.torch.networks import UNet from raygun.torch.networks.UNet import ConvPass import torch @@ -11,7 +12,15 @@ class ACLSDModel(torch.nn.Module): def __init__( self, - unet_kwargs={ + mt_unet_kwargs={ + "input_nc": 1, + "ngf": 12, + "fmap_inc_factor": 6, + "num_heads": 2, + "downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)], + "constant_upsample": True, + }, + ac_unet_kwargs={ "input_nc": 1, "ngf": 12, "fmap_inc_factor": 6, @@ -22,13 +31,22 @@ def __init__( ): super().__init__() - self.unet = UNet(**unet_kwargs) + self.mt_unet = UNet(**mt_unet_kwargs) + self.ac_unet = UNet(**ac_unet_kwargs) self.aff_head = ConvPass( - unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid" + mt_unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid" + ) + + self.lsd_head = ConvPass( # TODO: Make work without LSD + mt_unet_kwargs["ngf"], 10, [[1, 1, 1]], activation="Sigmoid" + ) + + self.ac_aff_head = ConvPass( + ac_unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid" ) - self.output_arrays = ["pred_affs"] + self.output_arrays = ["pred_affs", "pred_lsds", "pred_affs_ac"] self.data_dict = {} def add_log(self, writer, step): @@ -51,7 +69,12 @@ def add_log(self, writer, step): def forward(self, raw): self.data_dict.update({"raw": raw.detach()}) - z = self.unet(raw) - affs = self.aff_head(z) + a = self.mt_unet(raw) + # conv passes for MTLSD + affs = self.aff_head(a) + lsds = self.lsd_head(a) + b = self.ac_unet(lsds) + # conv pass for ACLSD + affs_ac = self.ac_aff_head(b) - return affs + return affs, lsds, affs_ac