First commit.
xiaochus committed Feb 1, 2018
1 parent 5054a2e commit 6e433be
Showing 5 changed files with 181 additions and 0 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -0,0 +1,35 @@
# MobileNet v2
A Keras 2 implementation of MobileNet V2.

Based on the paper [Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381).

Currently only the network structure is defined; a training function will be added later.

## Requirements
- Python 3.5
- Tensorflow-gpu 1.2.0
- Keras 2.1.3


## MobileNet v2 and inverted residual block architectures
**MobileNet v2:**

Each line of the table describes a sequence of one or more identical (modulo stride) layers, repeated n times. All layers in the same sequence have the same number c of output channels. The first layer of each sequence has stride s and all others use stride 1. All spatial convolutions use 3 × 3 kernels. The expansion factor t is always applied to the input size, that is, to the number of input channels of each layer.

![MobileNetV2](/images/net.jpg)
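
As a rough cross-check of the table, the sketch below traces the spatial size and channel count after each sequence. It assumes the (t, c, n, s) values passed in this repository's `MobileNetv2()` function and `'same'` padding, so a stride-s layer maps a size to `ceil(size / s)`. The names `STAGES` and `trace_shapes` are hypothetical helpers for illustration and are not part of the repository.

```python
import math

# (t, c, n, s) rows, copied from the calls in MobileNetv2() (assumed to match the table).
STAGES = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]


def trace_shapes(size=224, stem_channels=32, stem_stride=2):
    """Return (name, spatial_size, channels) after the stem conv and each sequence."""
    # With 'same' padding, a stride-s layer maps size -> ceil(size / s).
    size = math.ceil(size / stem_stride)
    shapes = [("stem", size, stem_channels)]
    for index, (t, c, n, s) in enumerate(STAGES, start=1):
        # Only the first of the n layers uses stride s; the others use stride 1,
        # so the spatial size shrinks at most once per sequence.
        size = math.ceil(size / s)
        shapes.append(("stage%d" % index, size, c))
    return shapes


for name, size, channels in trace_shapes():
    print(name, (size, size, channels))
```

For a 224 × 224 input this ends at a 7 × 7 × 320 feature map, which is what the final 1 × 1 convolution to 1280 channels receives.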

**Residual Block Architectures:**

![residual block architectures](/images/stru.jpg)
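
To make the expand / depthwise / project structure of a bottleneck concrete, here is a minimal sketch that counts the convolution kernel weights in one bottleneck layer. It is an illustration under stated assumptions: biases and batch-normalization parameters are ignored, and `bottleneck_conv_weights` is a hypothetical helper, not a function from this repository.

```python
def bottleneck_conv_weights(c_in, c_out, t, kernel=3):
    """Count convolution kernel weights in one bottleneck layer.

    Assumption: biases and batch-normalization parameters are not counted.
    """
    expanded = c_in * t                     # channels after the 1x1 expansion
    expand = c_in * expanded                # 1x1 pointwise expansion
    depthwise = kernel * kernel * expanded  # one kernel x kernel filter per channel
    project = expanded * c_out              # 1x1 linear projection
    return {
        "expand": expand,
        "depthwise": depthwise,
        "project": project,
        "total": expand + depthwise + project,
    }


# Example: the first layer of the 24-channel sequence (16 input channels, t=6).
print(bottleneck_conv_weights(c_in=16, c_out=24, t=6))
```

For this example the depthwise step contributes 864 of the 4704 kernel weights; the two 1 × 1 convolutions account for the rest.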

**Architecture of this implementation with (224, 224, 3) inputs and 1000 output classes:**

![architectures](/images/MobileNetv2.png)
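
The number of bottleneck layers and residual additions in the diagram can be cross-checked with a short sketch. It mirrors how `_inverted_residual_block` in this commit is written: the first layer of each sequence uses the sequence stride without a residual connection, and each of the remaining n - 1 layers uses stride 1 with a residual addition. `STAGE_REPEATS` and `residual_layout` are hypothetical names used only for this illustration.

```python
# (n, s) per sequence, copied from the calls in MobileNetv2().
STAGE_REPEATS = [(1, 1), (2, 2), (3, 2), (4, 2), (3, 1), (3, 2), (1, 1)]


def residual_layout(stages=STAGE_REPEATS):
    """Return, per sequence, a list of (stride, uses_residual_add) tuples."""
    layout = []
    for n, s in stages:
        # First layer: stride s, no residual. Remaining n - 1: stride 1, residual add.
        layout.append([(s, False)] + [(1, True)] * (n - 1))
    return layout


layout = residual_layout()
total_layers = sum(len(stage) for stage in layout)
residual_layers = sum(use_add for stage in layout for _, use_add in stage)
print(total_layers, residual_layers)
```

Under these assumptions the network has 17 bottleneck layers, 10 of which add their input back to their output.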

## Reference
- [Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381)

## Copyright
See [LICENSE](LICENSE) for details.


Binary file added images/MobileNetv2.png
Binary file added images/net.jpg
Binary file added images/stru.jpg
146 changes: 146 additions & 0 deletions mobilenet_v2.py
@@ -0,0 +1,146 @@
"""MobileNet v2 models for Keras.
# Reference
- [Inverted Residuals and Linear Bottlenecks: Mobile Networks for
   Classification, Detection and Segmentation]
   (https://arxiv.org/abs/1801.04381)
"""


from keras.models import Model
from keras.layers import Input, Conv2D, AveragePooling2D, Dropout
from keras.layers import Activation, BatchNormalization, add
from keras.applications.mobilenet import relu6, DepthwiseConv2D
from keras.utils.vis_utils import plot_model


from keras import backend as K


def _conv_block(inputs, filters, kernel, strides):
    """Convolution Block
    This function defines a 2D convolution operation with BN and relu6.
    # Arguments
        inputs: Tensor, input tensor of conv layer.
        filters: Integer, the dimensionality of the output space.
        kernel: An integer or tuple/list of 2 integers, specifying the
            width and height of the 2D convolution window.
        strides: An integer or tuple/list of 2 integers,
            specifying the strides of the convolution along the width and height.
            Can be a single integer to specify the same value for
            all spatial dimensions.
    # Returns
        Output tensor.
    """

    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs)
    x = BatchNormalization(axis=channel_axis)(x)
    return Activation(relu6)(x)


def _bottleneck(inputs, filters, kernel, t, s, r=False):
    """Bottleneck
    This function defines a basic bottleneck structure.
    # Arguments
        inputs: Tensor, input tensor of conv layer.
        filters: Integer, the dimensionality of the output space.
        kernel: An integer or tuple/list of 2 integers, specifying the
            width and height of the 2D convolution window.
        t: Integer, expansion factor.
            t is always applied to the input size.
        s: Integer, stride of the depthwise convolution, applied to
            both spatial dimensions.
        r: Boolean, whether to use the residuals.
    # Returns
        Output tensor.
    """

    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    tchannel = K.int_shape(inputs)[channel_axis] * t

    x = _conv_block(inputs, tchannel, (1, 1), (1, 1))

    x = DepthwiseConv2D(kernel, strides=(s, s), depth_multiplier=1, padding='same')(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation(relu6)(x)

    x = Conv2D(filters, (1, 1), strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=channel_axis)(x)

    if r:
        x = add([x, inputs])
    return x


def _inverted_residual_block(inputs, filters, kernel, t, strides, n):
    """Inverted Residual Block
    This function defines a sequence of 1 or more identical layers.
    # Arguments
        inputs: Tensor, input tensor of conv layer.
        filters: Integer, the dimensionality of the output space.
        kernel: An integer or tuple/list of 2 integers, specifying the
            width and height of the 2D convolution window.
        t: Integer, expansion factor.
            t is always applied to the input size.
        strides: Integer, stride of the first layer in the sequence,
            applied to both spatial dimensions. All later layers
            use stride 1.
        n: Integer, layer repeat times.
    # Returns
        Output tensor.
    """

    # First layer: applies the stride, no residual connection.
    x = _bottleneck(inputs, filters, kernel, t, strides)

    # Remaining n - 1 layers: stride 1 with a residual connection.
    for i in range(1, n):
        x = _bottleneck(x, filters, kernel, t, 1, True)

    return x


def MobileNetv2(input_shape, k):
    """MobileNetv2
    This function defines a MobileNetv2 architecture.
    # Arguments
        input_shape: An integer or tuple/list of 3 integers, shape
            of input tensor.
        k: Integer, number of output classes.
    # Returns
        MobileNetv2 model.
    """

    inputs = Input(shape=input_shape)
    x = _conv_block(inputs, 32, (3, 3), strides=(2, 2))

    x = _inverted_residual_block(x, 16, (3, 3), t=1, strides=1, n=1)
    x = _inverted_residual_block(x, 24, (3, 3), t=6, strides=2, n=2)
    x = _inverted_residual_block(x, 32, (3, 3), t=6, strides=2, n=3)
    x = _inverted_residual_block(x, 64, (3, 3), t=6, strides=2, n=4)
    x = _inverted_residual_block(x, 96, (3, 3), t=6, strides=1, n=3)
    x = _inverted_residual_block(x, 160, (3, 3), t=6, strides=2, n=3)
    x = _inverted_residual_block(x, 320, (3, 3), t=6, strides=1, n=1)

    x = _conv_block(x, 1280, (1, 1), strides=(1, 1))
    # x = GlobalAveragePooling2D()(x)
    # Pool over the full spatial extent of the feature map.
    x = AveragePooling2D((int(x.shape[1]), int(x.shape[2])))(x)
    x = Dropout(0.3)(x)
    # A 1x1 convolution with k filters acts as the classifier.
    x = Conv2D(k, (1, 1), padding='same')(x)
    output = Activation('softmax', name='softmax')(x)

    model = Model(inputs, output)
    plot_model(model, to_file='images/MobileNetv2.png', show_shapes=True)

    return model


if __name__ == '__main__':
    MobileNetv2((224, 224, 3), 1000)
