Skip to content

Commit 4c7c0b9

Browse files
committed
fix: Conciliating with latest main state
1 parent 5bc9bc5 commit 4c7c0b9

File tree

1 file changed

+1
-152
lines changed

1 file changed

+1
-152
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 1 addition & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
use nf_maxpool2d_layer, only: maxpool2d_layer
99
use nf_reshape_layer, only: reshape3d_layer
1010
use nf_rnn_layer, only: rnn_layer
11-
use nf_io_hdf5, only: get_hdf5_dataset
12-
use nf_keras, only: get_keras_h5_layers, keras_layer
1311
use nf_layer, only: layer
1412
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape, rnn
1513
use nf_loss, only: quadratic
@@ -96,155 +94,6 @@ module function network_from_layers(layers) result(res)
9694
end function network_from_layers
9795

9896

99-
module function network_from_keras(filename) result(res)
100-
character(*), intent(in) :: filename
101-
type(network) :: res
102-
type(keras_layer), allocatable :: keras_layers(:)
103-
type(layer), allocatable :: layers(:)
104-
character(:), allocatable :: layer_name
105-
character(:), allocatable :: object_name
106-
integer :: n
107-
108-
keras_layers = get_keras_h5_layers(filename)
109-
110-
allocate(layers(size(keras_layers)))
111-
112-
do n = 1, size(layers)
113-
114-
select case(keras_layers(n) % class)
115-
116-
case('Conv2D')
117-
118-
if (keras_layers(n) % kernel_size(1) &
119-
/= keras_layers(n) % kernel_size(2)) &
120-
error stop 'Non-square kernel in conv2d layer not supported.'
121-
122-
layers(n) = conv2d( &
123-
keras_layers(n) % filters, &
124-
!FIXME add support for non-square kernel
125-
keras_layers(n) % kernel_size(1), &
126-
get_activation_by_name(keras_layers(n) % activation) &
127-
)
128-
129-
case('Dense')
130-
131-
layers(n) = dense( &
132-
keras_layers(n) % units(1), &
133-
get_activation_by_name(keras_layers(n) % activation) &
134-
)
135-
136-
case('Flatten')
137-
layers(n) = flatten()
138-
139-
case('InputLayer')
140-
if (size(keras_layers(n) % units) == 1) then
141-
! input1d
142-
layers(n) = input(keras_layers(n) % units(1))
143-
else
144-
! input3d
145-
layers(n) = input(keras_layers(n) % units)
146-
end if
147-
148-
case('MaxPooling2D')
149-
150-
if (keras_layers(n) % pool_size(1) &
151-
/= keras_layers(n) % pool_size(2)) &
152-
error stop 'Non-square pool in maxpool2d layer not supported.'
153-
154-
if (keras_layers(n) % strides(1) &
155-
/= keras_layers(n) % strides(2)) &
156-
error stop 'Unequal strides in maxpool2d layer are not supported.'
157-
158-
layers(n) = maxpool2d( &
159-
!FIXME add support for non-square pool and stride
160-
keras_layers(n) % pool_size(1), &
161-
keras_layers(n) % strides(1) &
162-
)
163-
164-
case('Reshape')
165-
layers(n) = reshape(keras_layers(n) % target_shape)
166-
167-
case default
168-
error stop 'This Keras layer is not supported'
169-
170-
end select
171-
172-
end do
173-
174-
res = network(layers)
175-
176-
! Loop over layers and read weights and biases from the Keras h5 file
177-
! for each; currently only dense layers are implemented.
178-
do n = 2, size(res % layers)
179-
180-
layer_name = keras_layers(n) % name
181-
182-
select type(this_layer => res % layers(n) % p)
183-
184-
type is(conv2d_layer)
185-
! Read biases from file
186-
object_name = '/model_weights/' // layer_name // '/' &
187-
// layer_name // '/bias:0'
188-
call get_hdf5_dataset(filename, object_name, this_layer % biases)
189-
190-
! Read weights from file
191-
object_name = '/model_weights/' // layer_name // '/' &
192-
// layer_name // '/kernel:0'
193-
call get_hdf5_dataset(filename, object_name, this_layer % kernel)
194-
195-
type is(dense_layer)
196-
197-
! Read biases from file
198-
object_name = '/model_weights/' // layer_name // '/' &
199-
// layer_name // '/bias:0'
200-
call get_hdf5_dataset(filename, object_name, this_layer % biases)
201-
202-
! Read weights from file
203-
object_name = '/model_weights/' // layer_name // '/' &
204-
// layer_name // '/kernel:0'
205-
call get_hdf5_dataset(filename, object_name, this_layer % weights)
206-
207-
type is(flatten_layer)
208-
! Nothing to do
209-
continue
210-
211-
type is(maxpool2d_layer)
212-
! Nothing to do
213-
continue
214-
215-
type is(reshape3d_layer)
216-
! Nothing to do
217-
continue
218-
219-
type is(rnn_layer)
220-
221-
! Read biases from file
222-
object_name = '/model_weights/' // layer_name // '/' &
223-
// layer_name // '/simple_rnn_cell_23/bias:0'
224-
call get_hdf5_dataset(filename, object_name, this_layer % biases)
225-
226-
! Read weights from file
227-
object_name = '/model_weights/' // layer_name // '/' &
228-
// layer_name // '/simple_rnn_cell_23/kernel:0'
229-
call get_hdf5_dataset(filename, object_name, this_layer % weights)
230-
231-
! Read recurrent weights from file
232-
object_name = '/model_weights/' // layer_name // '/' &
233-
// layer_name // '/simple_rnn_cell_23/recurrent_kernel:0'
234-
call get_hdf5_dataset(filename, object_name, this_layer % recurrent)
235-
236-
class default
237-
error stop 'Internal error in network_from_keras(); ' &
238-
// 'mismatch in layer types between the Keras and ' &
239-
// 'neural-fortran model layers.'
240-
241-
end select
242-
243-
end do
244-
245-
end function network_from_keras
246-
247-
24897
pure function get_activation_by_name(activation_name) result(res)
24998
! Workaround to get activation_function with some
25099
! hardcoded default parameters by its name.
@@ -298,7 +147,7 @@ pure function get_activation_by_name(activation_name) result(res)
298147

299148
end function get_activation_by_name
300149

301-
pure module subroutine backward(self, output, loss)
150+
module subroutine backward(self, output, loss)
302151
class(network), intent(in out) :: self
303152
real, intent(in) :: output(:)
304153
class(loss_type), intent(in), optional :: loss

0 commit comments

Comments
 (0)