Skip to content

Commit 553a55e

Browse files
committed
Implementation of stride
1 parent 824bb13 commit 553a55e

File tree

5 files changed

+68
-30
lines changed

5 files changed

+68
-30
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ module subroutine init(self, input_shape)
6363
!! Input layer dimensions
6464
end subroutine init
6565

66-
module subroutine forward(self, input)
66+
pure module subroutine forward(self, input)
6767
!! Apply a forward pass on the `conv1d` layer.
6868
class(conv1d_layer), intent(in out) :: self
6969
!! A `conv1d_layer` instance

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ module subroutine init(self, input_shape)
5757

5858
end subroutine init
5959

60-
module subroutine forward(self, input)
60+
pure module subroutine forward(self, input)
6161
implicit none
6262
class(conv1d_layer), intent(in out) :: self
6363
real, intent(in) :: input(:,:)
@@ -125,13 +125,13 @@ pure module subroutine backward(self, input, gradient)
125125
do n = 1, self % filters
126126
do j = 1, self % width
127127
iws = self % stride * (j-1) + 1
128-
iwe = max(iws + self % kernel_size - 1, input_width)
128+
iwe = min(iws + self % kernel_size - 1, input_width)
129129

130130
do k = 1, self % channels
131131
! Weight gradient: accumulate contribution from the input window.
132-
dw_local(n,k,1:iws-iwe+1) = dw_local(n,k,1:iws-iwe+1) + input(k,iws:iwe) * gdz(n,j)
132+
dw_local(n,k,1:iwe-iws+1) = dw_local(n,k,1:iwe-iws+1) + input(k,iws:iwe) * gdz(n,j)
133133
! Input gradient: propagate gradient back to the input window.
134-
self % gradient(k,iws:iwe) = self % gradient(k,iws:iwe) + self % kernel(n,k,1:iws-iwe+1) * gdz(n,j)
134+
self % gradient(k,iws:iwe) = self % gradient(k,iws:iwe) + self % kernel(n,k,1:iwe-iws+1) * gdz(n,j)
135135
end do
136136
end do
137137
end do

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,12 @@ module subroutine init(self, input_shape)
3030
integer, intent(in) :: input_shape(:)
3131

3232
self % channels = input_shape(1)
33-
self % width = (input_shape(2) - self % kernel_size + 1) / self % stride(1)
34-
self % height = (input_shape(3) - self % kernel_size + 1) / self % stride(2)
33+
34+
self % width = (input_shape(2) - self % kernel_size) / self % stride(1) + 1
35+
if (mod(input_shape(2) - self % kernel_size , self % stride(1)) /= 0) self % width = self % width + 1
36+
37+
self % height = (input_shape(3) - self % kernel_size) / self % stride(2) + 1
38+
if (mod(input_shape(3) - self % kernel_size , self % stride(2)) /= 0) self % height = self % height + 1
3539

3640
! Output of shape filters x width x height
3741
allocate(self % output(self % filters, self % width, self % height))
@@ -89,22 +93,24 @@ pure module subroutine forward(self, input)
8993
iend = input_width - istart + 1
9094
jend = input_height - jstart + 1
9195

92-
convolution: do concurrent(i = istart:iend, j = jstart:jend)
96+
! convolution: do concurrent(i = istart:iend, j = jstart:jend)
97+
convolution: do concurrent(i = 1:self % width, j = 1:self%height)
9398

9499
! Start and end indices of the input data on the filter window
95100
! iws and jws are also coincidentally the indices of the output matrix
96-
iws = i - half_window ! TODO kernel_width
97-
iwe = i + half_window ! TODO kernel_width
98-
jws = j - half_window ! TODO kernel_height
99-
jwe = j + half_window ! TODO kernel_height
101+
iws = istart + self %stride(1) * (i-1) - half_window ! TODO kernel_width
102+
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
103+
104+
jws = jstart + self %stride(2) * (j-1) - half_window ! TODO kernel_height
105+
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
100106

101107
! Compute the inner tensor product, sum(w_ij * x_ij), for each filter.
102108
do concurrent(n = 1:self % filters)
103-
self % z(n,iws,jws) = sum(self % kernel(n,:,:,:) * input(:,iws:iwe,jws:jwe))
109+
self % z(n,i,j) = sum(self % kernel(n,:,1:iwe-iws+1,1:jwe-jws+1) * input(:,iws:iwe,jws:jwe))
104110
end do
105111

106112
! Add bias to the inner product.
107-
self % z(:,iws,jws) = self % z(:,iws,jws) + self % biases
113+
self % z(:,i,j) = self % z(:,i,j) + self % biases
108114

109115
end do convolution
110116

@@ -160,21 +166,28 @@ pure module subroutine backward(self, input, gradient)
160166
do concurrent( &
161167
n = 1:self % filters, &
162168
k = 1:self % channels, &
163-
i = istart:iend, &
164-
j = jstart:jend &
169+
i = 1:self % width, &
170+
j = 1:self % height &
171+
!i = istart:iend, &
172+
!j = jstart:jend &
165173
)
166174
! Start and end indices of the input data on the filter window
167-
iws = i - half_window ! TODO kernel_width
168-
iwe = i + half_window ! TODO kernel_width
169-
jws = j - half_window ! TODO kernel_height
170-
jwe = j + half_window ! TODO kernel_height
175+
!iws = i - half_window ! TODO kernel_width
176+
!iwe = i + half_window ! TODO kernel_width
177+
!jws = j - half_window ! TODO kernel_height
178+
!jwe = j + half_window ! TODO kernel_height
179+
iws = istart + self %stride(1) * (i-1) - half_window ! TODO kernel_width
180+
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
181+
182+
jws = jstart + self %stride(2) * (j-1) - half_window ! TODO kernel_height
183+
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
171184

172185
! dL/dw = sum(dL/dy * sigma'(z) * x)
173186
dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,iws:iwe,jws:jwe)
174187

175188
! dL/dx = dL/dy * sigma'(z) .inner. w
176-
self % gradient(k,i,j) = self % gradient(k,i,j) &
177-
+ sum(gdz(n,iws:iwe,jws:jwe) * self % kernel(n,k,:,:))
189+
self % gradient(k,iws:iwe,jws:jwe) = self % gradient(k,iws:iwe,jws:jwe) &
190+
+ gdz(n,iws:iwe,jws:jwe) * self % kernel(n,k,1:iwe-iws+1,1:jwe-jws+1)
178191

179192
end do
180193

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ module function conv2d(filters, kernel_width, kernel_height, activation, stride)
6767
integer, intent(in), optional :: stride(:)
6868
type(layer) :: res
6969

70-
integer :: stride_tmp(2)
70+
integer, allocatable :: stride_tmp(:)
7171
class(activation_function), allocatable :: activation_tmp
7272

7373
! Enforce kernel_width == kernel_height for now;
@@ -76,12 +76,6 @@ module function conv2d(filters, kernel_width, kernel_height, activation, stride)
7676
if (kernel_width /= kernel_height) &
7777
error stop 'kernel_width must equal kernel_height in a conv2d layer'
7878

79-
if (size(stride) /= 2 ) &
80-
error stop 'size of stride must be equal to 2 in a conv2d layer'
81-
82-
if (stride(1) < 1 .or. stride(2) < 1) &
83-
error stop 'stride must be >= 1 in a conv2d layer'
84-
8579
res % name = 'conv2d'
8680

8781
if (present(activation)) then
@@ -98,9 +92,15 @@ module function conv2d(filters, kernel_width, kernel_height, activation, stride)
9892
stride_tmp = [1, 1]
9993
endif
10094

95+
if (size(stride_tmp) /= 2 ) &
96+
error stop 'size of stride must be equal to 2 in a conv2d layer'
97+
98+
if (stride_tmp(1) < 1 .or. stride_tmp(2) < 1) &
99+
error stop 'stride must be >= 1 in a conv2d layer'
100+
101101
allocate( &
102102
res % p, &
103-
source=conv2d_layer(filters, kernel_width, activation_tmp, stride) &
103+
source=conv2d_layer(filters, kernel_width, activation_tmp, stride_tmp) &
104104
)
105105

106106
end function conv2d

test/test_conv2d_layer.f90

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,30 @@ program test_conv2d_layer
5959
call this_layer % set(sample_input)
6060
end select
6161

62+
deallocate(sample_input)
63+
64+
call conv_layer % forward(input_layer)
65+
call conv_layer % get_output(output)
66+
67+
if (.not. all(abs(output) < tolerance)) then
68+
ok = .false.
69+
write(stderr, '(a)') 'conv2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
70+
end if
71+
72+
! Minimal conv2d layer: 1 channel, 17x17 pixel image, stride=[3, 4];
73+
allocate(sample_input(1, 17, 17))
74+
sample_input = 0
75+
76+
input_layer = input(1, 17, 17)
77+
conv_layer = conv(filters, kernel_size, kernel_size, stride=[3, 4])
78+
call conv_layer % init(input_layer)
79+
80+
select type(this_layer => input_layer % p); type is(input3d_layer)
81+
call this_layer % set(sample_input)
82+
end select
83+
84+
deallocate(sample_input)
85+
6286
call conv_layer % forward(input_layer)
6387
call conv_layer % get_output(output)
6488

@@ -67,6 +91,7 @@ program test_conv2d_layer
6791
write(stderr, '(a)') 'conv2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
6892
end if
6993

94+
! Summary
7095
if (ok) then
7196
print '(a)', 'test_conv2d_layer: All tests passed.'
7297
else

0 commit comments

Comments (0)