Showing 7 changed files with 687 additions and 0 deletions (three of the files are reproduced below).
blendbatchnorm.py (new file, +69 lines):
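This helper folds each BatchNorm3d into the Conv3d that precedes it. In eval mode, batch norm computes y = gamma * (x - mu) / sqrt(var + eps) + beta from its running statistics, so applying it after a convolution with weight W and bias is equivalent to a single convolution with weight A * W and bias A * bias + (beta - gamma * mu / sqrt(var + eps)), where A = gamma / sqrt(var + eps) is a per-output-channel scale. That is exactly the rewrite performed below.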
# https://github.com/MIPT-Oulu/pytorch_bn_fusion/blob/master/bn_fusion.py
import torch
import torch.nn as nn


def fuse_bn_sequential(block):
    """
    Takes a sequential block and fuses each batch normalization layer
    into the convolution that precedes it
    :param block: nn.Sequential. Source block
    :return: nn.Module. Converted block (a single module if only one
        remains after fusion)
    """
    if not isinstance(block, nn.Sequential):
        return block
    stack = []
    for m in block.children():
        # Fuse only when a BatchNorm3d directly follows a Conv3d;
        # keep every other module (including unfusable BatchNorms) as-is
        if (
            isinstance(m, nn.BatchNorm3d)
            and stack
            and isinstance(stack[-1], nn.Conv3d)
        ):
            bn_st_dict = m.state_dict()
            conv_st_dict = stack[-1].state_dict()

            # BatchNorm params
            eps = m.eps
            mu = bn_st_dict["running_mean"]
            var = bn_st_dict["running_var"]
            gamma = bn_st_dict["weight"]

            if "bias" in bn_st_dict:
                beta = bn_st_dict["bias"]
            else:
                beta = torch.zeros(gamma.size(0)).float().to(gamma.device)

            # Conv params
            W = conv_st_dict["weight"]
            if "bias" in conv_st_dict:
                bias = conv_st_dict["bias"]
            else:
                bias = torch.zeros(W.size(0)).float().to(gamma.device)

            # Fold BN into the conv: the fused weight is A * W and the
            # fused bias is A * bias + b, with A = gamma / sqrt(var + eps)
            # and b = beta - gamma * mu / sqrt(var + eps)
            denom = torch.sqrt(var + eps)
            b = beta - gamma.mul(mu).div(denom)
            A = gamma.div(denom)
            bias *= A
            # broadcast the per-output-channel scale over the weight tensor
            A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

            W.mul_(A)
            bias.add_(b)

            stack[-1].weight.data.copy_(W)
            if stack[-1].bias is None:
                stack[-1].bias = torch.nn.Parameter(bias)
            else:
                stack[-1].bias.data.copy_(bias)
        else:
            stack.append(m)

    if len(stack) > 1:
        return nn.Sequential(*stack)
    else:
        return stack[0]


def fuse_bn_recursively(model):
    for module_name in model._modules:
        model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
        if len(model._modules[module_name]._modules) > 0:
            fuse_bn_recursively(model._modules[module_name])

    return model
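Because fusion only bakes the running statistics into the convolution weights, the fused block should reproduce the original outputs up to floating-point round-off. A minimal sanity check, assuming the imports and functions above (the shapes are illustrative, not from the commit):

block = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm3d(4),
    nn.ReLU(inplace=True),
)
# give the BN non-trivial affine parameters and running statistics
nn.init.normal_(block[1].weight)
nn.init.normal_(block[1].bias)
block[1].running_mean.normal_()
block[1].running_var.uniform_(0.5, 2.0)
block.eval()  # fusion reads the running statistics, so use eval mode

x = torch.randn(1, 1, 8, 8, 8)
with torch.no_grad():
    y_ref = block(x)
    y_fused = fuse_bn_sequential(block)(x)

print(torch.allclose(y_ref, y_fused, atol=1e-5))  # expected: True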
Conversion script (new file, +56 lines). It loads a trained MeshNet checkpoint, fuses batch normalization into the convolutions, and exports the result to TensorFlow.js via meshnet2tfjs:
import torch

from blendbatchnorm import fuse_bn_recursively
from meshnet2tfjs import meshnet2tfjs

from meshnet import (
    MeshNet,
    enMesh_checkpoint,
)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)


def preprocess_image(img, qmin=0.01, qmax=0.99):
    """Unit interval preprocessing based on quantiles"""
    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
    return img


# def preprocess_image(img):
#     """Unit interval preprocessing"""
#     img = (img - img.min()) / (img.max() - img.min())
#     return img


# how many classes the model predicts
n_classes = 3
# path to the architecture configuration
config_file = "modelAE.json"
# how many channels the saved model has
model_channels = 15
# path to the saved model weights
model_path = "model.pth"
# tfjs model output directory
tfjs_model_dir = "model_tfjs"

meshnet_model = enMesh_checkpoint(
    in_channels=1,
    n_classes=n_classes,
    channels=model_channels,
    config_file=config_file,
)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load(model_path, map_location=device)
meshnet_model.load_state_dict(checkpoint)

meshnet_model.eval()

meshnet_model.to(device)
# fuse batch normalization into the convolutions before export
mnm = fuse_bn_recursively(meshnet_model)
del meshnet_model
mnm.model.eval()


meshnet2tfjs(mnm, tfjs_model_dir)
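For reference, this is how the preprocessing and the fused model would be applied at inference time; a hedged sketch, assuming the variables above (the input shape is an assumption, not part of the commit):

# illustrative inference with the fused model
vol = torch.rand(1, 1, 38, 38, 38, device=device)  # batch, channel, D, H, W
vol = preprocess_image(vol)
with torch.inference_mode():
    labels = mnm(vol).argmax(dim=1)  # (1, D, H, W) per-voxel class indices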
meshnet.py (new file, +138 lines). It defines the configurable MeshNet architecture and a gradient-checkpointing variant, enMesh_checkpoint:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
import json


def set_channel_num(config, in_channels, n_classes, channels):
    """
    Takes a configuration json for a MeshNet-style convolutional neural
    network and updates it to use the specified number of input channels,
    output classes, and channels for every layer except the input and
    output layers.
    Args:
        config (dict): The configuration json for the network.
        in_channels (int): The number of input channels.
        n_classes (int): The number of output classes.
        channels (int): The number of channels for each hidden layer.
    Returns:
        dict: The updated configuration json.
    """
    # input layer
    config["layers"][0]["in_channels"] = in_channels
    config["layers"][0]["out_channels"] = channels

    # output layer
    config["layers"][-1]["in_channels"] = channels
    config["layers"][-1]["out_channels"] = n_classes

    # hidden layers
    for layer in config["layers"][1:-1]:
        layer["in_channels"] = layer["out_channels"] = channels

    return config


def construct_layer(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):
    """Constructs a configurable convolutional block with optional batch
    normalization and dropout.
    Args:
        dropout_p (float): Dropout probability. Default is 0.
        bnorm (bool): Whether to include batch normalization. Default is True.
        gelu (bool): Whether to use ELU activation instead of ReLU
            (note: despite the flag name, nn.ELU is used). Default is False.
        *args: Additional positional arguments to pass to nn.Conv3d.
        **kwargs: Additional keyword arguments to pass to nn.Conv3d.
    Returns:
        nn.Sequential: Convolutional block with optional batch
            normalization, activation, and dropout.
    """
    layers = []
    layers.append(nn.Conv3d(*args, **kwargs))
    if bnorm:
        # track_running_stats=True keeps the running statistics that
        # BN fusion needs (set it to False to run forward-mode AD)
        layers.append(
            nn.BatchNorm3d(kwargs["out_channels"], track_running_stats=True)
        )
    layers.append(nn.ELU(inplace=True) if gelu else nn.ReLU(inplace=True))
    if dropout_p > 0:
        layers.append(nn.Dropout3d(dropout_p))
    return nn.Sequential(*layers)


def init_weights(model):
    """Set weights to be Kaiming normal for all Convs"""
    for m in model.modules():
        if isinstance(
            m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        ):
            # nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu"))
            nn.init.kaiming_normal_(
                m.weight, mode="fan_out", nonlinearity="relu"
            )
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


class MeshNet(nn.Module):
    """Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf"""

    def __init__(self, in_channels, n_classes, channels, config_file, fat=None):
        """Init"""
        with open(config_file, "r") as f:
            config = set_channel_num(
                json.load(f), in_channels, n_classes, channels
            )

        # optionally widen ("fatten") selected layers by 1.5x:
        # "i" widens the input end, "io" both ends, "b" the middle
        if fat is not None:
            chn = int(channels * 1.5)
            if fat in {"i", "io"}:
                config["layers"][0]["out_channels"] = chn
                config["layers"][1]["in_channels"] = chn
            if fat == "io":
                config["layers"][-1]["in_channels"] = chn
                config["layers"][-2]["out_channels"] = chn
            if fat == "b":
                config["layers"][3]["out_channels"] = chn
                config["layers"][4]["in_channels"] = chn

        super().__init__()

        layers = [
            construct_layer(
                dropout_p=config["dropout_p"],
                bnorm=config["bnorm"],
                gelu=config["gelu"],
                **block_kwargs,
            )
            for block_kwargs in config["layers"]
        ]
        # keep only the conv of the final block (no BN/activation/dropout)
        layers[-1] = layers[-1][0]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        """Forward pass"""
        x = self.model(x)
        return x


class enMesh_checkpoint(MeshNet):
    """MeshNet variant that uses gradient checkpointing during training
    to trade compute for memory."""

    def train_forward(self, x):
        y = x
        y.requires_grad_()
        y = checkpoint_sequential(
            self.model, len(self.model), y, preserve_rng_state=False
        )
        return y

    def eval_forward(self, x):
        """Forward pass"""
        self.model.eval()
        with torch.inference_mode():
            x = self.model(x)
        return x

    def forward(self, x):
        if self.training:
            return self.train_forward(x)
        else:
            return self.eval_forward(x)
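The config file consumed by MeshNet is a JSON object with top-level dropout_p, bnorm, and gelu fields plus one dict of nn.Conv3d keyword arguments per layer. A hypothetical sketch of building such a config in Python and normalizing it with set_channel_num (the real modelAE.json ships with the model; the dilation pattern below is illustrative, in the spirit of the MeshNet paper, not copied from that file):

# placeholder channel counts (-1) are overwritten by set_channel_num
config = {
    "dropout_p": 0,
    "bnorm": True,
    "gelu": False,
    "layers": [
        {"in_channels": -1, "out_channels": -1, "kernel_size": 3,
         "padding": d, "stride": 1, "dilation": d}
        for d in (1, 1, 2, 4, 8, 1, 1)
    ]
    + [{"in_channels": -1, "out_channels": -1, "kernel_size": 1,
        "padding": 0, "stride": 1, "dilation": 1}],
}
config = set_channel_num(config, in_channels=1, n_classes=3, channels=15)
assert config["layers"][0]["in_channels"] == 1
assert config["layers"][-1]["out_channels"] == 3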