|
8 | 8 | use nf_maxpool2d_layer, only: maxpool2d_layer |
9 | 9 | use nf_reshape_layer, only: reshape3d_layer |
10 | 10 | use nf_rnn_layer, only: rnn_layer |
11 | | - use nf_io_hdf5, only: get_hdf5_dataset |
12 | | - use nf_keras, only: get_keras_h5_layers, keras_layer |
13 | 11 | use nf_layer, only: layer |
14 | 12 | use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape, rnn |
15 | 13 | use nf_loss, only: quadratic |
@@ -96,155 +94,6 @@ module function network_from_layers(layers) result(res) |
96 | 94 | end function network_from_layers |
97 | 95 |
|
98 | 96 |
|
99 | | - module function network_from_keras(filename) result(res) |
100 | | - character(*), intent(in) :: filename |
101 | | - type(network) :: res |
102 | | - type(keras_layer), allocatable :: keras_layers(:) |
103 | | - type(layer), allocatable :: layers(:) |
104 | | - character(:), allocatable :: layer_name |
105 | | - character(:), allocatable :: object_name |
106 | | - integer :: n |
107 | | - |
108 | | - keras_layers = get_keras_h5_layers(filename) |
109 | | - |
110 | | - allocate(layers(size(keras_layers))) |
111 | | - |
112 | | - do n = 1, size(layers) |
113 | | - |
114 | | - select case(keras_layers(n) % class) |
115 | | - |
116 | | - case('Conv2D') |
117 | | - |
118 | | - if (keras_layers(n) % kernel_size(1) & |
119 | | - /= keras_layers(n) % kernel_size(2)) & |
120 | | - error stop 'Non-square kernel in conv2d layer not supported.' |
121 | | - |
122 | | - layers(n) = conv2d( & |
123 | | - keras_layers(n) % filters, & |
124 | | - !FIXME add support for non-square kernel |
125 | | - keras_layers(n) % kernel_size(1), & |
126 | | - get_activation_by_name(keras_layers(n) % activation) & |
127 | | - ) |
128 | | - |
129 | | - case('Dense') |
130 | | - |
131 | | - layers(n) = dense( & |
132 | | - keras_layers(n) % units(1), & |
133 | | - get_activation_by_name(keras_layers(n) % activation) & |
134 | | - ) |
135 | | - |
136 | | - case('Flatten') |
137 | | - layers(n) = flatten() |
138 | | - |
139 | | - case('InputLayer') |
140 | | - if (size(keras_layers(n) % units) == 1) then |
141 | | - ! input1d |
142 | | - layers(n) = input(keras_layers(n) % units(1)) |
143 | | - else |
144 | | - ! input3d |
145 | | - layers(n) = input(keras_layers(n) % units) |
146 | | - end if |
147 | | - |
148 | | - case('MaxPooling2D') |
149 | | - |
150 | | - if (keras_layers(n) % pool_size(1) & |
151 | | - /= keras_layers(n) % pool_size(2)) & |
152 | | - error stop 'Non-square pool in maxpool2d layer not supported.' |
153 | | - |
154 | | - if (keras_layers(n) % strides(1) & |
155 | | - /= keras_layers(n) % strides(2)) & |
156 | | - error stop 'Unequal strides in maxpool2d layer are not supported.' |
157 | | - |
158 | | - layers(n) = maxpool2d( & |
159 | | - !FIXME add support for non-square pool and stride |
160 | | - keras_layers(n) % pool_size(1), & |
161 | | - keras_layers(n) % strides(1) & |
162 | | - ) |
163 | | - |
164 | | - case('Reshape') |
165 | | - layers(n) = reshape(keras_layers(n) % target_shape) |
166 | | - |
167 | | - case default |
168 | | - error stop 'This Keras layer is not supported' |
169 | | - |
170 | | - end select |
171 | | - |
172 | | - end do |
173 | | - |
174 | | - res = network(layers) |
175 | | - |
176 | | - ! Loop over layers and read weights and biases from the Keras h5 file |
177 | | - ! for each; currently only dense layers are implemented. |
178 | | - do n = 2, size(res % layers) |
179 | | - |
180 | | - layer_name = keras_layers(n) % name |
181 | | - |
182 | | - select type(this_layer => res % layers(n) % p) |
183 | | - |
184 | | - type is(conv2d_layer) |
185 | | - ! Read biases from file |
186 | | - object_name = '/model_weights/' // layer_name // '/' & |
187 | | - // layer_name // '/bias:0' |
188 | | - call get_hdf5_dataset(filename, object_name, this_layer % biases) |
189 | | - |
190 | | - ! Read weights from file |
191 | | - object_name = '/model_weights/' // layer_name // '/' & |
192 | | - // layer_name // '/kernel:0' |
193 | | - call get_hdf5_dataset(filename, object_name, this_layer % kernel) |
194 | | - |
195 | | - type is(dense_layer) |
196 | | - |
197 | | - ! Read biases from file |
198 | | - object_name = '/model_weights/' // layer_name // '/' & |
199 | | - // layer_name // '/bias:0' |
200 | | - call get_hdf5_dataset(filename, object_name, this_layer % biases) |
201 | | - |
202 | | - ! Read weights from file |
203 | | - object_name = '/model_weights/' // layer_name // '/' & |
204 | | - // layer_name // '/kernel:0' |
205 | | - call get_hdf5_dataset(filename, object_name, this_layer % weights) |
206 | | - |
207 | | - type is(flatten_layer) |
208 | | - ! Nothing to do |
209 | | - continue |
210 | | - |
211 | | - type is(maxpool2d_layer) |
212 | | - ! Nothing to do |
213 | | - continue |
214 | | - |
215 | | - type is(reshape3d_layer) |
216 | | - ! Nothing to do |
217 | | - continue |
218 | | - |
219 | | - type is(rnn_layer) |
220 | | - |
221 | | - ! Read biases from file |
222 | | - object_name = '/model_weights/' // layer_name // '/' & |
223 | | - // layer_name // '/simple_rnn_cell_23/bias:0' |
224 | | - call get_hdf5_dataset(filename, object_name, this_layer % biases) |
225 | | - |
226 | | - ! Read weights from file |
227 | | - object_name = '/model_weights/' // layer_name // '/' & |
228 | | - // layer_name // '/simple_rnn_cell_23/kernel:0' |
229 | | - call get_hdf5_dataset(filename, object_name, this_layer % weights) |
230 | | - |
231 | | - ! Read recurrent weights from file |
232 | | - object_name = '/model_weights/' // layer_name // '/' & |
233 | | - // layer_name // '/simple_rnn_cell_23/recurrent_kernel:0' |
234 | | - call get_hdf5_dataset(filename, object_name, this_layer % recurrent) |
235 | | - |
236 | | - class default |
237 | | - error stop 'Internal error in network_from_keras(); ' & |
238 | | - // 'mismatch in layer types between the Keras and ' & |
239 | | - // 'neural-fortran model layers.' |
240 | | - |
241 | | - end select |
242 | | - |
243 | | - end do |
244 | | - |
245 | | - end function network_from_keras |
246 | | - |
247 | | - |
248 | 97 | pure function get_activation_by_name(activation_name) result(res) |
249 | 98 | ! Workaround to get activation_function with some |
250 | 99 | ! hardcoded default parameters by its name. |
@@ -298,7 +147,7 @@ pure function get_activation_by_name(activation_name) result(res) |
298 | 147 |
|
299 | 148 | end function get_activation_by_name |
300 | 149 |
|
301 | | - pure module subroutine backward(self, output, loss) |
| 150 | + module subroutine backward(self, output, loss) |
302 | 151 | class(network), intent(in out) :: self |
303 | 152 | real, intent(in) :: output(:) |
304 | 153 | class(loss_type), intent(in), optional :: loss |
|
0 commit comments