-
Notifications
You must be signed in to change notification settings - Fork 17
/
norm_example.py
58 lines (44 loc) · 1.67 KB
/
norm_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Example: convert a PyTorch model containing several normalization layers to
# Keras (.h5) with nobuco, reload it, and then convert it to TFLite.
import os
# Hide all CUDA devices so the conversion runs on CPU. This must be set before
# torch / tensorflow are imported below, otherwise it has no effect on them.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import nobuco
from nobuco import ChannelOrder, ChannelOrderingStrategy
# nobuco's custom Keras layer; needed as a custom object when reloading the .h5 model.
from nobuco.layers.weight import WeightLayer
import torch
from torch import nn
from torchvision.ops import FrozenBatchNorm2d
import tensorflow as tf
from tensorflow.lite.python.lite import TFLiteConverter
import keras
class DummyModel(nn.Module):
    """Toy module that feeds one input through five normalization layers.

    Each layer sees the same input independently; nothing is chained. All
    layers are built for 32 channels, and the layer norm normalizes over a
    trailing dimension of size 256, so inputs must match both.
    """

    def __init__(self):
        super().__init__()
        channels = 32
        # Submodules are registered in this fixed order: batch, frozen batch,
        # instance, group, layer.
        self.batch_norm = nn.BatchNorm2d(channels)
        self.frozen_batch_norm = FrozenBatchNorm2d(channels)
        self.instance_norm = nn.InstanceNorm2d(channels, affine=True)
        self.group_norm = nn.GroupNorm(num_groups=4, num_channels=channels, affine=True)
        self.layer_norm = nn.LayerNorm(normalized_shape=256, elementwise_affine=True)

    def forward(self, x):
        """Apply every norm to ``x`` and return the five results as a tuple.

        Output order: batch norm, frozen batch norm, instance norm,
        group norm, layer norm.
        """
        norms = (
            self.batch_norm,
            self.frozen_batch_norm,
            self.instance_norm,
            self.group_norm,
            self.layer_norm,
        )
        # Called left to right, so the op order matches the explicit version.
        return tuple(norm(x) for norm in norms)
# Build the example model in inference mode and a sample input used for tracing.
model = DummyModel().eval()
x = torch.randn(4, 32, 256, 256)

# Convert PyTorch -> Keras. The commented-out option is left as a toggle for
# experimenting with nobuco's input channel-order handling.
keras_model = nobuco.pytorch_to_keras(
    model,
    args=[x],
    # inputs_channel_order=ChannelOrder.PYTORCH,
)

model_path = 'norm'
h5_path = model_path + '.h5'
tflite_path = model_path + '.tflite'

keras_model.save(h5_path)
print('Model saved')

# nobuco's WeightLayer is a custom layer, so it must be supplied for Keras to
# deserialize the saved model.
custom_objects = {'WeightLayer': WeightLayer}
keras_model_restored = keras.models.load_model(h5_path, custom_objects=custom_objects)
print('Model loaded')

# from_keras_model_file is part of the TF1-style converter API
# (tf.compat.v1.lite.TFLiteConverter).
converter = TFLiteConverter.from_keras_model_file(h5_path, custom_objects=custom_objects)
# `converter.target_ops` is a deprecated alias that only forwards to
# `target_spec.supported_ops`; set the supported op sets directly.
# SELECT_TF_OPS allows ops without a TFLite builtin kernel to fall back to
# TensorFlow kernels.
converter.target_spec.supported_ops = {
    tf.lite.OpsSet.SELECT_TF_OPS,
    tf.lite.OpsSet.TFLITE_BUILTINS,
}
tflite_model = converter.convert()
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)