diff --git a/setup.cfg b/setup.cfg index 2f505fd..a88b85e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,9 +2,9 @@ name = LeibNetz version = 0.2.0 author = Jeff Rhoades, Larissa Heinrich -author_email = rhoadesj@hhmi.org, heinrichl@janelia.hhmi.org +author_email = rhoadesj@hhmi.org url = https://github.com/janelia-cellmap/LeibNetz -description = A lightweight and modular library for rapidly developing and constructing PyTorch models deep learning. +description = A lightweight and modular library for rapidly developing and constructing PyTorch models for deep learning. long_description = file: README.md long_description_content_type = text/markdown keywords = image-segmentation, convolutional-neural-networks, deep-learning, pytorch diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index 65fcb68..38aad90 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -1,11 +1,12 @@ -from typing import Iterable, Optional, Sequence, Tuple, Union +from typing import Iterable, Sequence, Tuple import networkx as nx from torch import device import torch from torch.nn import Module import numpy as np from leibnetz.nodes import Node -from funlib.learn.torch.models.conv4d import Conv4d + +# from funlib.learn.torch.models.conv4d import Conv4d # from model_opt.apis import optimize @@ -43,9 +44,8 @@ def __init__( self.apply( lambda m: ( torch.nn.init.kaiming_normal_(m.weight, mode="fan_out") - if isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.Conv3d) - or isinstance(m, Conv4d) + if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d) + # or isinstance(m, Conv4d) else None ) ) @@ -53,9 +53,8 @@ def __init__( self.apply( lambda m: ( torch.nn.init.xavier_normal_(m.weight) - if isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.Conv3d) - or isinstance(m, Conv4d) + if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d) + # or isinstance(m, Conv4d) else None ) ) @@ -63,9 +62,8 @@ def __init__( self.apply( lambda m: ( torch.nn.init.orthogonal_(m.weight) - if isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.Conv3d) - or isinstance(m, Conv4d) + if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d) + # or isinstance(m, Conv4d) else None ) ) diff --git a/src/leibnetz/nodes/node_ops.py b/src/leibnetz/nodes/node_ops.py index 29d59f8..3f06905 100644 --- a/src/leibnetz/nodes/node_ops.py +++ b/src/leibnetz/nodes/node_ops.py @@ -1,6 +1,7 @@ from torch import nn import numpy as np -from funlib.learn.torch.models.conv4d import Conv4d + +# from funlib.learn.torch.models.conv4d import Conv4d class ConvPass(nn.Module): @@ -95,10 +96,13 @@ def __init__( self.dims = len(kernel_size) try: - conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims] + # TODO: Implement Conv4d + # conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims] + conv = {2: nn.Conv2d, 3: nn.Conv3d}[self.dims] except KeyError: raise ValueError( - f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D" + # f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D" + f"Only 2D and 3D convolutions are supported, not {self.dims}D" ) layers.append( diff --git a/src/leibnetz/nodes/resample_ops.py b/src/leibnetz/nodes/resample_ops.py index 63dcd55..ca32c1f 100644 --- a/src/leibnetz/nodes/resample_ops.py +++ b/src/leibnetz/nodes/resample_ops.py @@ -1,7 +1,6 @@ -import math from torch import nn -import torch -from funlib.learn.torch.models.conv4d import Conv4d + +# from funlib.learn.torch.models.conv4d import Conv4d from logging import getLogger @@ -52,7 +51,15 @@ def __init__( self.dims = len(kernel_sizes) - conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims] + try: + # TODO: Implement Conv4d + # conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims] + conv = {2: nn.Conv2d, 3: nn.Conv3d}[self.dims] + except KeyError: + raise ValueError( + # f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D" + f"Only 2D and 3D convolutions are supported, not {self.dims}D" + ) try: layers.append(