Commit: Add conversion example
Mmasoud1 committed Mar 28, 2024
1 parent 7fbde87 commit 4c1ff39
Showing 7 changed files with 687 additions and 0 deletions.
Binary file modified css/images/BrainchopMoreRobustModels.gif
69 changes: 69 additions & 0 deletions py2tfjs/conversion_example/blendbatchnorm.py
@@ -0,0 +1,69 @@
# https://github.com/MIPT-Oulu/pytorch_bn_fusion/blob/master/bn_fusion.py
import torch
import torch.nn as nn


def fuse_bn_sequential(block):
"""
This function takes a sequential block and fuses the batch normalization with convolution
:param model: nn.Sequential. Source resnet model
:return: nn.Sequential. Converted block
"""
if not isinstance(block, nn.Sequential):
return block
stack = []
for m in block.children():
if isinstance(m, nn.BatchNorm3d):
            if stack and isinstance(stack[-1], nn.Conv3d):
bn_st_dict = m.state_dict()
conv_st_dict = stack[-1].state_dict()

# BatchNorm params
eps = m.eps
mu = bn_st_dict["running_mean"]
var = bn_st_dict["running_var"]
gamma = bn_st_dict["weight"]

if "bias" in bn_st_dict:
beta = bn_st_dict["bias"]
else:
beta = torch.zeros(gamma.size(0)).float().to(gamma.device)

# Conv params
W = conv_st_dict["weight"]
if "bias" in conv_st_dict:
bias = conv_st_dict["bias"]
else:
bias = torch.zeros(W.size(0)).float().to(gamma.device)

denom = torch.sqrt(var + eps)
b = beta - gamma.mul(mu).div(denom)
A = gamma.div(denom)
bias *= A
A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

W.mul_(A)
bias.add_(b)

                stack[-1].weight.data.copy_(W)
                if stack[-1].bias is None:
                    stack[-1].bias = torch.nn.Parameter(bias)
                else:
                    stack[-1].bias.data.copy_(bias)
            else:
                # a BatchNorm3d that does not follow a Conv3d cannot be fused; keep it
                stack.append(m)
        else:
            stack.append(m)

if len(stack) > 1:
return nn.Sequential(*stack)
else:
return stack[0]


def fuse_bn_recursively(model):
for module_name in model._modules:
model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
if len(model._modules[module_name]._modules) > 0:
fuse_bn_recursively(model._modules[module_name])

return model
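
For reference, the fusion above rewrites each convolution as W' = W · γ/√(σ² + ε) and b' = β + (b − μ) · γ/√(σ² + ε), which is exactly what the denom, A, and b terms compute. A minimal equivalence check, assuming blendbatchnorm.py is on the import path; the toy block below is illustrative, not part of the commit:

import torch
import torch.nn as nn
from blendbatchnorm import fuse_bn_sequential

# toy Conv3d + BatchNorm3d + ReLU block in the layout fuse_bn_sequential expects
block = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1),
    nn.BatchNorm3d(4),
    nn.ReLU(inplace=True),
)
block.train()
_ = block(torch.randn(2, 1, 8, 8, 8))  # populate the BN running statistics
block.eval()

x = torch.randn(1, 1, 8, 8, 8)
with torch.no_grad():
    y_ref = block(x)                   # Conv3d + BN output, before fusion
    fused = fuse_bn_sequential(block)  # fusion mutates the conv in place
    y_fused = fused(x)

# the fused Conv3d should reproduce Conv3d + BN up to float32 tolerance
assert torch.allclose(y_ref, y_fused, atol=1e-5)
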
56 changes: 56 additions & 0 deletions py2tfjs/conversion_example/convert.py
@@ -0,0 +1,56 @@
import torch

from blendbatchnorm import fuse_bn_recursively
from meshnet2tfjs import meshnet2tfjs

from meshnet import (
MeshNet,
enMesh_checkpoint,
)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)


def preprocess_image(img, qmin=0.01, qmax=0.99):
    """Rescale intensities so the [qmin, qmax] quantile range maps onto [0, 1]"""
    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
    return img


# def preprocess_image(img):
# """Unit interval preprocessing"""
# img = (img - img.min()) / (img.max() - img.min())
# return img


# number of classes the model predicts
n_classes = 3
# path to the JSON file that specifies the architecture
config_file = "modelAE.json"
# number of channels in the saved model
model_channels = 15
# path to the saved model weights
model_path = "model.pth"
# tfjs model output directory
tfjs_model_dir = "model_tfjs"

meshnet_model = enMesh_checkpoint(
in_channels=1,
n_classes=n_classes,
channels=model_channels,
config_file=config_file,
)

# map_location keeps loading robust when the checkpoint was saved on another device
checkpoint = torch.load(model_path, map_location=device)
meshnet_model.load_state_dict(checkpoint)

meshnet_model.eval()

meshnet_model.to(device)
mnm = fuse_bn_recursively(meshnet_model)
del meshnet_model
mnm.model.eval()


meshnet2tfjs(mnm, tfjs_model_dir)
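
A quick smoke test before conversion can catch checkpoint/config mismatches. A hedged sketch, not part of the commit; the 38-cube input is an arbitrary size, since MeshNet is fully convolutional and its padded dilated convolutions preserve the spatial shape:

dummy = torch.randn(1, 1, 38, 38, 38, device=device)
with torch.inference_mode():
    logits = mnm(dummy)
# expect torch.Size([1, n_classes, 38, 38, 38])
print(logits.shape)
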
138 changes: 138 additions & 0 deletions py2tfjs/conversion_example/meshnet.py
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
import json


def set_channel_num(config, in_channels, n_classes, channels):
"""
Takes a configuration json for a convolutional neural network of MeshNet architecture and changes it to have the specified number of input channels, output classes, and number of channels that each layer except the input and output layers have.
Args:
config (dict): The configuration json for the network.
in_channels (int): The number of input channels.
n_classes (int): The number of output classes.
channels (int): The number of channels that each layer except the input and output layers will have.
Returns:
dict: The updated configuration json.
"""
# input layer
config["layers"][0]["in_channels"] = in_channels
config["layers"][0]["out_channels"] = channels

# output layer
config["layers"][-1]["in_channels"] = channels
config["layers"][-1]["out_channels"] = n_classes

# hidden layers
for layer in config["layers"][1:-1]:
layer["in_channels"] = layer["out_channels"] = channels

return config
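
For illustration, a minimal three-layer config (hypothetical field values; the real modelAE.json also carries top-level dropout_p, bnorm, and gelu keys) before and after set_channel_num:

config = {
    "layers": [
        {"in_channels": None, "out_channels": None, "kernel_size": 3, "padding": 1, "dilation": 1},
        {"in_channels": None, "out_channels": None, "kernel_size": 3, "padding": 2, "dilation": 2},
        {"in_channels": None, "out_channels": None, "kernel_size": 1, "padding": 0, "dilation": 1},
    ]
}
config = set_channel_num(config, in_channels=1, n_classes=3, channels=15)
# channel counts are now 1 -> 15, 15 -> 15, 15 -> 3
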


def construct_layer(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):
"""Constructs a configurable Convolutional block with Batch Normalization and Dropout.
Args:
dropout_p (float): Dropout probability. Default is 0.
bnorm (bool): Whether to include batch normalization. Default is True.
gelu (bool): Whether to use GELU activation. Default is False.
*args: Additional positional arguments to pass to nn.Conv3d.
**kwargs: Additional keyword arguments to pass to nn.Conv3d.
Returns:
nn.Sequential: A sequential container of Convolutional block with optional Batch Normalization and Dropout.
"""
layers = []
layers.append(nn.Conv3d(*args, **kwargs))
if bnorm:
        # track_running_stats=True keeps the running statistics that BN fusion
        # (blendbatchnorm.py) folds into the conv; forward-mode AD would
        # instead require track_running_stats=False
        layers.append(
            nn.BatchNorm3d(kwargs["out_channels"], track_running_stats=True)
        )
        # note: the `gelu` flag currently selects ELU, not GELU
        layers.append(nn.ELU(inplace=True) if gelu else nn.ReLU(inplace=True))
if dropout_p > 0:
layers.append(nn.Dropout3d(dropout_p))
return nn.Sequential(*layers)
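
For example, one hidden block as construct_layer would assemble it (parameter values are illustrative):

block = construct_layer(
    dropout_p=0.1,
    bnorm=True,
    gelu=False,
    in_channels=15,
    out_channels=15,
    kernel_size=3,
    padding=2,
    dilation=2,
)
# -> nn.Sequential(Conv3d, BatchNorm3d, ReLU, Dropout3d)
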


def init_weights(model):
"""Set weights to be xavier normal for all Convs"""
for m in model.modules():
if isinstance(
m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
):
# nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu"))
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)


class MeshNet(nn.Module):
"""Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf"""

def __init__(self, in_channels, n_classes, channels, config_file, fat=None):
"""Init"""
with open(config_file, "r") as f:
config = set_channel_num(
json.load(f), in_channels, n_classes, channels
)

if fat is not None:
chn = int(channels * 1.5)
if fat in {"i", "io"}:
config["layers"][0]["out_channels"] = chn
config["layers"][1]["in_channels"] = chn
if fat == "io":
config["layers"][-1]["in_channels"] = chn
config["layers"][-2]["out_channels"] = chn
if fat == "b":
config["layers"][3]["out_channels"] = chn
config["layers"][4]["in_channels"] = chn

super(MeshNet, self).__init__()

layers = [
construct_layer(
dropout_p=config["dropout_p"],
bnorm=config["bnorm"],
gelu=config["gelu"],
**block_kwargs,
)
for block_kwargs in config["layers"]
]
        # keep only the Conv3d of the final block (drop its BN/activation/dropout)
        layers[-1] = layers[-1][0]
self.model = nn.Sequential(*layers)
init_weights(self.model)

def forward(self, x):
"""Forward pass"""
x = self.model(x)
return x


class enMesh_checkpoint(MeshNet):
def train_forward(self, x):
y = x
y.requires_grad_()
y = checkpoint_sequential(
self.model, len(self.model), y, preserve_rng_state=False
)
return y

def eval_forward(self, x):
"""Forward pass"""
self.model.eval()
with torch.inference_mode():
x = self.model(x)
return x

def forward(self, x):
if self.training:
return self.train_forward(x)
else:
return self.eval_forward(x)
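
Training-time usage, as a hedged sketch (assumes modelAE.json is present): checkpoint_sequential trades compute for memory by recomputing each segment's activations during the backward pass, which is what lets large 3D volumes fit on a GPU.

model = enMesh_checkpoint(
    in_channels=1, n_classes=3, channels=15, config_file="modelAE.json"
)
model.train()
out = model(torch.randn(1, 1, 32, 32, 32))  # checkpointed forward
out.sum().backward()                         # segments recomputed here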