Skip to content

Commit

Permalink
feat: simplify, add skip items, resnet groups
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Dec 28, 2022
1 parent 05d390b commit 97deaa5
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 148 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# A-UNet

A library that provides building blocks to customize UNets, in PyTorch.
A toolbox that provides hackable building blocks for (1D/2D/3D) UNets, in PyTorch.

## Install
```bash
Expand Down
6 changes: 3 additions & 3 deletions a_unet/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __call__(self, *a, **ka):
class Inner:
def __init__(self):
self.args = a
self.kwargs = ka
self.__dict__.update(**ka)

def __call__(self, *b, **kb):
if override:
Expand Down Expand Up @@ -320,11 +320,11 @@ def forward(x: Tensor, features: Tensor) -> Tensor:
return Module([to_scale_shift, norm], forward)


def MergeAdd(*args, **kwargs):
def MergeAdd():
return Module([], lambda x, y, *_: x + y)


def MergeCat(dim: int, channels: int, **kwargs) -> nn.Module:
def MergeCat(dim: int, channels: int) -> nn.Module:
conv = Conv(dim=dim, in_channels=channels * 2, out_channels=channels, kernel_size=1)
return Module([conv], lambda x, y, *_: conv(torch.cat([x, y], dim=1)))

Expand Down
Loading

0 comments on commit 97deaa5

Please sign in to comment.