Skip to content

Commit 2fe1946

Browse files
committed
Add generic locally_connected wrapper around locally_connected1d
1 parent 1d3ce3a commit 2fe1946

File tree

5 files changed

+41
-37
lines changed

5 files changed

+41
-37
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
3333
| Embedding | `embedding` | n/a | 2 |||
3434
| Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 |||
3535
| Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 |||
36-
| Locally connected (1-d) | `locally_connected1d` | `input2d`, `locally_connected1d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
36+
| Locally connected (1-d) | `locally_connected` | `input`, `locally_connected`, `conv`, `maxpool`, `reshape` | 2 |||
3737
| Convolutional (1-d and 2-d) | `conv` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
3838
| Max-pooling (1-d and 2-d) | `maxpool` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
3939
| Linear (2-d) | `linear2d` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 |||

example/cnn_mnist_1d.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist_1d
22

33
use nf, only: network, sgd, &
4-
input, maxpool, flatten, dense, reshape, locally_connected1d, &
4+
input, maxpool, flatten, dense, reshape, locally_connected, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -21,9 +21,9 @@ program cnn_mnist_1d
2121
net = network([ &
2222
input(784), &
2323
reshape(28, 28), &
24-
locally_connected1d(filters=8, kernel_size=3, activation=relu()), &
24+
locally_connected(filters=8, kernel_size=3, activation=relu()), &
2525
maxpool(pool_width=2, stride=2), &
26-
locally_connected1d(filters=16, kernel_size=3, activation=relu()), &
26+
locally_connected(filters=16, kernel_size=3, activation=relu()), &
2727
maxpool(pool_width=2, stride=2), &
2828
dense(10, activation=softmax()) &
2929
])

src/nf.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module nf
1111
input, &
1212
layernorm, &
1313
linear2d, &
14-
locally_connected1d, &
14+
locally_connected, &
1515
maxpool, &
1616
reshape, &
1717
self_attention

src/nf/nf_layer_constructors.f90

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module nf_layer_constructors
1515
flatten, &
1616
input, &
1717
linear2d, &
18-
locally_connected1d, &
18+
locally_connected, &
1919
maxpool, &
2020
reshape, &
2121
self_attention, &
@@ -154,6 +154,38 @@ end function conv2d
154154
end interface conv
155155

156156

157+
interface locally_connected
158+
159+
module function locally_connected1d(filters, kernel_size, activation) result(res)
160+
!! 1-d locally connected network constructor
161+
!!
162+
!! This layer is for building 1-d locally connected network.
163+
!! Although the established convention is to call these layers 1-d,
164+
!! the shape of the data is actually 2-d: image width,
165+
!! and the number of channels.
166+
!! A locally connected 1d layer must not be the first layer in the network.
167+
!!
168+
!! Example:
169+
!!
170+
!! ```
171+
!! use nf, only :: locally_connected1d, layer
172+
!! type(layer) :: locally_connected1d_layer
173+
!! locally_connected1d_layer = dense(filters=32, kernel_size=3)
174+
!! locally_connected1d_layer = dense(filters=32, kernel_size=3, activation='relu')
175+
!! ```
176+
integer, intent(in) :: filters
177+
!! Number of filters in the output of the layer
178+
integer, intent(in) :: kernel_size
179+
!! Width of the convolution window, commonly 3 or 5
180+
class(activation_function), intent(in), optional :: activation
181+
!! Activation function (default sigmoid)
182+
type(layer) :: res
183+
!! Resulting layer instance
184+
end function locally_connected1d
185+
186+
end interface locally_connected
187+
188+
157189
interface maxpool
158190

159191
module function maxpool1d(pool_width, stride) result(res)
@@ -290,33 +322,6 @@ module function flatten() result(res)
290322
!! Resulting layer instance
291323
end function flatten
292324

293-
module function locally_connected1d(filters, kernel_size, activation) result(res)
294-
!! 1-d locally connected network constructor
295-
!!
296-
!! This layer is for building 1-d locally connected network.
297-
!! Although the established convention is to call these layers 1-d,
298-
!! the shape of the data is actually 2-d: image width,
299-
!! and the number of channels.
300-
!! A locally connected 1d layer must not be the first layer in the network.
301-
!!
302-
!! Example:
303-
!!
304-
!! ```
305-
!! use nf, only :: locally_connected1d, layer
306-
!! type(layer) :: locally_connected1d_layer
307-
!! locally_connected1d_layer = dense(filters=32, kernel_size=3)
308-
!! locally_connected1d_layer = dense(filters=32, kernel_size=3, activation='relu')
309-
!! ```
310-
integer, intent(in) :: filters
311-
!! Number of filters in the output of the layer
312-
integer, intent(in) :: kernel_size
313-
!! Width of the convolution window, commonly 3 or 5
314-
class(activation_function), intent(in), optional :: activation
315-
!! Activation function (default sigmoid)
316-
type(layer) :: res
317-
!! Resulting layer instance
318-
end function locally_connected1d
319-
320325
module function linear2d(out_features) result(res)
321326
!! Rank-2 (sequence_length, out_features) linear layer constructor.
322327
!! sequence_length is determined at layer initialization, based on the

test/test_locally_connected1d_layer.f90

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_locally_connected1d_layer
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: locally_connected1d, input, layer
4+
use nf, only: locally_connected, input, layer
55
use nf_input2d_layer, only: input2d_layer
66

77
implicit none
@@ -12,7 +12,7 @@ program test_locally_connected1d_layer
1212
real, parameter :: tolerance = 1e-7
1313
logical :: ok = .true.
1414

15-
locally_connected_1d_layer = locally_connected1d(filters, kernel_size)
15+
locally_connected_1d_layer = locally_connected(filters, kernel_size)
1616

1717
if (.not. locally_connected_1d_layer % name == 'locally_connected1d') then
1818
ok = .false.
@@ -52,7 +52,7 @@ program test_locally_connected1d_layer
5252
sample_input = 0
5353

5454
input_layer = input(1, 3)
55-
locally_connected_1d_layer = locally_connected1d(filters, kernel_size)
55+
locally_connected_1d_layer = locally_connected(filters, kernel_size)
5656
call locally_connected_1d_layer % init(input_layer)
5757

5858
select type(this_layer => input_layer % p); type is(input2d_layer)
@@ -62,7 +62,6 @@ program test_locally_connected1d_layer
6262
call locally_connected_1d_layer % forward(input_layer)
6363
call locally_connected_1d_layer % get_output(output)
6464

65-
6665
if (.not. all(abs(output) < tolerance)) then
6766
ok = .false.
6867
write(stderr, '(a)') 'locally_connected1d layer with zero input and sigmoid function must forward to all 0.5.. failed'

0 commit comments

Comments
 (0)