Skip to content

Commit a37a0c6

Browse files
committed
Generic maxpool constructor for maxpool1d_layer and maxpool2d_layer
1 parent 1c3defe commit a37a0c6

12 files changed

+95
-100
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
3535
| Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 |||
3636
| Locally connected (1-d) | `locally_connected1d` | `input2d`, `locally_connected1d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
3737
| Convolutional (1-d and 2-d) | `conv` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
38-
| Max-pooling (1-d) | `maxpool1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
39-
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 |||
38+
| Max-pooling (1-d and 2-d) | `maxpool` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
4039
| Linear (2-d) | `linear2d` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 |||
4140
| Self-attention | `self_attention` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 |||
4241
| Layer Normalization | `layernorm` | `linear2d`, `self_attention` | 2 |||

example/cnn_mnist.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist
22

33
use nf, only: network, sgd, &
4-
input, conv, maxpool2d, flatten, dense, reshape, &
4+
input, conv, maxpool, flatten, dense, reshape, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -22,9 +22,9 @@ program cnn_mnist
2222
input(784), &
2323
reshape(1, 28, 28), &
2424
conv(filters=8, kernel_width=3, kernel_height=3, activation=relu()), &
25-
maxpool2d(pool_size=2), &
25+
maxpool(pool_width=2, stride=2), &
2626
conv(filters=16, kernel_width=3, kernel_height=3, activation=relu()), &
27-
maxpool2d(pool_size=2), &
27+
maxpool(pool_width=2, stride=2), &
2828
dense(10, activation=softmax()) &
2929
])
3030

example/cnn_mnist_1d.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist_1d
22

33
use nf, only: network, sgd, &
4-
input, maxpool1d, flatten, dense, reshape, locally_connected1d, &
4+
input, maxpool, flatten, dense, reshape, locally_connected1d, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -22,9 +22,9 @@ program cnn_mnist_1d
2222
input(784), &
2323
reshape(28, 28), &
2424
locally_connected1d(filters=8, kernel_size=3, activation=relu()), &
25-
maxpool1d(pool_size=2), &
25+
maxpool(pool_width=2, stride=2), &
2626
locally_connected1d(filters=16, kernel_size=3, activation=relu()), &
27-
maxpool1d(pool_size=2), &
27+
maxpool(pool_width=2, stride=2), &
2828
dense(10, activation=softmax()) &
2929
])
3030

src/nf.f90

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ module nf
1212
layernorm, &
1313
linear2d, &
1414
locally_connected1d, &
15-
maxpool1d, &
16-
maxpool2d, &
15+
maxpool, &
1716
reshape, &
1817
self_attention
1918
use nf_loss, only: mse, quadratic

src/nf/nf_layer_constructors.f90

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ module nf_layer_constructors
1616
input, &
1717
linear2d, &
1818
locally_connected1d, &
19-
maxpool1d, &
20-
maxpool2d, &
19+
maxpool, &
2120
reshape, &
2221
self_attention, &
2322
embedding, &
@@ -151,10 +150,61 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
151150
type(layer) :: res
152151
!! Resulting layer instance
153152
end function conv2d
154-
153+
155154
end interface conv
156155

157156

157+
interface maxpool
158+
159+
module function maxpool1d(pool_width, stride) result(res)
160+
!! 1-d maxpooling layer constructor.
161+
!!
162+
!! This layer is for downscaling other layers, typically `conv1d`.
163+
!!
164+
!! This specific function is available under a generic name `maxpool`.
165+
!!
166+
!! Example:
167+
!!
168+
!! ```
169+
!! use nf, only :: maxpool1d, layer
170+
!! type(layer) :: maxpool1d_layer
171+
!! maxpool1d_layer = maxpool1d(pool_width=2, stride=2)
172+
!! ```
173+
integer, intent(in) :: pool_width
174+
!! Width of the pooling window, commonly 2
175+
integer, intent(in) :: stride
176+
!! Stride of the pooling window, commonly equal to `pool_width`;
177+
type(layer) :: res
178+
!! Resulting layer instance
179+
end function maxpool1d
180+
181+
module function maxpool2d(pool_width, pool_height, stride) result(res)
182+
!! 2-d maxpooling layer constructor.
183+
!!
184+
!! This layer is for downscaling other layers, typically `conv2d`.
185+
!!
186+
!! This specific function is available under a generic name `maxpool`.
187+
!!
188+
!! Example:
189+
!!
190+
!! ```
191+
!! use nf, only :: maxpool2d, layer
192+
!! type(layer) :: maxpool2d_layer
193+
!! maxpool2d_layer = maxpool2d(pool_width=2, pool_height=2, stride=2)
194+
!! ```
195+
integer, intent(in) :: pool_width
196+
!! Width of the pooling window, commonly 2
197+
integer, intent(in) :: pool_height
198+
!! Height of the pooling window; currently must be equal to pool_width
199+
integer, intent(in) :: stride
200+
!! Stride of the pooling window, commonly equal to `pool_width`;
201+
type(layer) :: res
202+
!! Resulting layer instance
203+
end function maxpool2d
204+
205+
end interface maxpool
206+
207+
158208
interface reshape
159209

160210
module function reshape2d(dim1, dim2) result(res)
@@ -267,50 +317,6 @@ module function locally_connected1d(filters, kernel_size, activation) result(res
267317
!! Resulting layer instance
268318
end function locally_connected1d
269319

270-
module function maxpool1d(pool_size, stride) result(res)
271-
!! 1-d maxpooling layer constructor.
272-
!!
273-
!! This layer is for downscaling other layers, typically `conv1d`.
274-
!!
275-
!! Example:
276-
!!
277-
!! ```
278-
!! use nf, only :: maxpool1d, layer
279-
!! type(layer) :: maxpool1d_layer
280-
!! maxpool1d_layer = maxpool1d(pool_size=2)
281-
!! maxpool1d_layer = maxpool1d(pool_size=2, stride=3)
282-
!! ```
283-
integer, intent(in) :: pool_size
284-
!! Width of the pooling window, commonly 2
285-
integer, intent(in), optional :: stride
286-
!! Stride of the pooling window, commonly equal to `pool_size`;
287-
!! Defaults to `pool_size` if omitted.
288-
type(layer) :: res
289-
!! Resulting layer instance
290-
end function maxpool1d
291-
292-
module function maxpool2d(pool_size, stride) result(res)
293-
!! 2-d maxpooling layer constructor.
294-
!!
295-
!! This layer is for downscaling other layers, typically `conv2d`.
296-
!!
297-
!! Example:
298-
!!
299-
!! ```
300-
!! use nf, only :: maxpool2d, layer
301-
!! type(layer) :: maxpool2d_layer
302-
!! maxpool2d_layer = maxpool2d(pool_size=2)
303-
!! maxpool2d_layer = maxpool2d(pool_size=2, stride=3)
304-
!! ```
305-
integer, intent(in) :: pool_size
306-
!! Width of the pooling window, commonly 2
307-
integer, intent(in), optional :: stride
308-
!! Stride of the pooling window, commonly equal to `pool_size`;
309-
!! Defaults to `pool_size` if omitted.
310-
type(layer) :: res
311-
!! Resulting layer instance
312-
end function maxpool2d
313-
314320
module function linear2d(out_features) result(res)
315321
!! Rank-2 (sequence_length, out_features) linear layer constructor.
316322
!! sequence_length is determined at layer initialization, based on the

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -179,58 +179,49 @@ module function input3d(dim1, dim2, dim3) result(res)
179179
res % initialized = .true.
180180
end function input3d
181181

182-
module function maxpool1d(pool_size, stride) result(res)
183-
integer, intent(in) :: pool_size
184-
integer, intent(in), optional :: stride
185-
integer :: stride_
182+
module function maxpool1d(pool_width, stride) result(res)
183+
integer, intent(in) :: pool_width
184+
integer, intent(in) :: stride
186185
type(layer) :: res
187186

188-
if (pool_size < 2) &
189-
error stop 'pool_size must be >= 2 in a maxpool1d layer'
187+
if (pool_width < 2) &
188+
error stop 'pool_width must be >= 2 in a maxpool1d layer'
190189

191-
! Stride defaults to pool_size if not provided
192-
if (present(stride)) then
193-
stride_ = stride
194-
else
195-
stride_ = pool_size
196-
end if
197-
198-
if (stride_ < 1) &
190+
if (stride < 1) &
199191
error stop 'stride must be >= 1 in a maxpool1d layer'
200192

201193
res % name = 'maxpool1d'
202194

203195
allocate( &
204196
res % p, &
205-
source=maxpool1d_layer(pool_size, stride_) &
197+
source=maxpool1d_layer(pool_width, stride) &
206198
)
207199

208200
end function maxpool1d
209201

210-
module function maxpool2d(pool_size, stride) result(res)
211-
integer, intent(in) :: pool_size
212-
integer, intent(in), optional :: stride
213-
integer :: stride_
202+
module function maxpool2d(pool_width, pool_height, stride) result(res)
203+
integer, intent(in) :: pool_width
204+
integer, intent(in) :: pool_height
205+
integer, intent(in) :: stride
214206
type(layer) :: res
215207

216-
if (pool_size < 2) &
217-
error stop 'pool_size must be >= 2 in a maxpool2d layer'
208+
if (pool_width < 2) &
209+
error stop 'pool_width must be >= 2 in a maxpool2d layer'
218210

219-
! Stride defaults to pool_size if not provided
220-
if (present(stride)) then
221-
stride_ = stride
222-
else
223-
stride_ = pool_size
224-
end if
211+
! Enforce pool_width == pool_height for now;
212+
! If non-square poolings show to be desired, we'll relax this constraint
213+
! and refactor maxpool2d_layer to work with non-square kernels.
214+
if (pool_width /= pool_height) &
215+
error stop 'pool_width must equal pool_height in a maxpool2d layer'
225216

226-
if (stride_ < 1) &
217+
if (stride < 1) &
227218
error stop 'stride must be >= 1 in a maxpool2d layer'
228219

229220
res % name = 'maxpool2d'
230221

231222
allocate( &
232223
res % p, &
233-
source=maxpool2d_layer(pool_size, stride_) &
224+
source=maxpool2d_layer(pool_width, stride) &
234225
)
235226

236227
end function maxpool2d

test/test_conv1d_network.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_conv1d_network
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: conv, input, network, dense, sgd, maxpool1d
4+
use nf, only: conv, input, network, dense, sgd, maxpool
55

66
implicit none
77

@@ -87,7 +87,7 @@ program test_conv1d_network
8787
cnn = network([ &
8888
input(1, 8), &
8989
conv(filters=1, kernel_width=3), &
90-
maxpool1d(pool_size=2), &
90+
maxpool(pool_width=2, stride=2), &
9191
conv(filters=1, kernel_width=3), &
9292
dense(1) &
9393
])
@@ -122,7 +122,7 @@ program test_conv1d_network
122122
cnn = network([ &
123123
input(1, 12), &
124124
conv(filters=1, kernel_width=3), & ! 1x12x12 input, 1x10x10 output
125-
maxpool1d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
125+
maxpool(pool_width=2, stride=2), & ! 1x10x10 input, 1x5x5 output
126126
conv(filters=1, kernel_width=3), & ! 1x5x5 input, 1x3x3 output
127127
dense(9) & ! 9 outputs
128128
])

test/test_conv2d_network.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_conv2d_network
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: conv, input, network, dense, sgd, maxpool2d
4+
use nf, only: conv, input, network, dense, sgd, maxpool
55

66
implicit none
77

@@ -87,7 +87,7 @@ program test_conv2d_network
8787
cnn = network([ &
8888
input(1, 8, 8), &
8989
conv(filters=1, kernel_width=3, kernel_height=3), &
90-
maxpool2d(pool_size=2), &
90+
maxpool(pool_width=2, pool_height=2, stride=2), &
9191
conv(filters=1, kernel_width=3, kernel_height=3), &
9292
dense(1) &
9393
])
@@ -122,7 +122,7 @@ program test_conv2d_network
122122
cnn = network([ &
123123
input(1, 12, 12), &
124124
conv(filters=1, kernel_width=3, kernel_height=3), & ! 1x12x12 input, 1x10x10 output
125-
maxpool2d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
125+
maxpool(pool_width=2, pool_height=2, stride=2), & ! 1x10x10 input, 1x5x5 output
126126
conv(filters=1, kernel_width=3, kernel_height=3), & ! 1x5x5 input, 1x3x3 output
127127
dense(9) & ! 9 outputs
128128
])

test/test_get_set_network_params.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
program test_get_set_network_params
22
use iso_fortran_env, only: stderr => error_unit
3-
use nf, only: conv, dense, flatten, input, maxpool2d, network
3+
use nf, only: conv, dense, flatten, input, network
44
implicit none
55
type(network) :: net
66
logical :: ok = .true.

test/test_insert_flatten.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_insert_flatten
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf, only: network, input, conv, maxpool2d, flatten, dense, reshape
4+
use nf, only: network, input, conv, maxpool, flatten, dense, reshape
55

66
implicit none
77

@@ -34,13 +34,13 @@ program test_insert_flatten
3434
net = network([ &
3535
input(3, 32, 32), &
3636
conv(filters=1, kernel_width=3, kernel_height=3), &
37-
maxpool2d(pool_size=2, stride=2), &
37+
maxpool(pool_width=2, stride=2), &
3838
dense(10) &
3939
])
4040

4141
if (.not. net % layers(4) % name == 'flatten') then
4242
ok = .false.
43-
write(stderr, '(a)') 'flatten layer inserted after maxpool2d.. failed'
43+
write(stderr, '(a)') 'flatten layer inserted after maxpool.. failed'
4444
end if
4545

4646
net = network([ &

0 commit comments

Comments
 (0)