-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_layer.m
49 lines (39 loc) · 972 Bytes
/
create_layer.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
function layer = create_layer(insize, hidsize, transfcn, W, b, trainfcn, varargin)
%CREATE_LAYER Build a one-input, one-layer shallow network object.
%   LAYER = CREATE_LAYER(INSIZE, HIDSIZE, TRANSFCN, W, B, TRAINFCN) returns
%   a NETWORK with a single input of size INSIZE connected to a single
%   layer of HIDSIZE neurons using transfer function TRANSFCN. The layer
%   has a bias and is connected to the network output. The input weight
%   matrix IW{1,1} is set to W, the bias b{1} is set to B, and TRAINFCN is
%   used as the network training function.
%   NOTE(review): W is presumably HIDSIZE-by-INSIZE and B HIDSIZE-by-1;
%   no size check is done here -- confirm against callers.
%
%   LAYER = CREATE_LAYER(..., 'Name', NAME) uses NAME (default 'Layer')
%   for both the network name and the layer name. The option name is
%   matched case-insensitively.
%
%   The data division function is set to 'dividetrain' and the only plot
%   function is 'plotperform'.

% --- Options --------------------------------------------------------------
opts = inputParser;
opts.CaseSensitive = false;
addParameter(opts, 'Name', 'Layer');
parse(opts, varargin{:});
layerName = opts.Results.Name;

% --- Topology: 1 input -> 1 biased layer -> output ------------------------
layer = network;
layer.name = layerName;
layer.numInputs = 1;
layer.numLayers = 1;
layer.inputConnect(1,1) = 1;
layer.outputConnect = 1;
layer.biasConnect = 1;

% --- Sizes and transfer function ------------------------------------------
layer.input.size = insize;
layer.layers{1}.name = layerName;
layer.layers{1}.size = hidsize;
layer.layers{1}.transferFcn = transfcn;

% --- Initial parameter values ---------------------------------------------
layer.IW{1,1} = W;
layer.b{1} = b;

% --- Data division, plotting and training ---------------------------------
layer.divideFcn = 'dividetrain';
layer.plotFcns = {'plotperform'};
% NOTE(review): nnetParam presumably supplies a default/empty parameter set
% for plotperform; the original author marked this as "Dummy?" -- confirm.
layer.plotParams = {nnetParam};
layer.trainFcn = trainfcn;

% Rebuild the object from its struct form. TODO: the intended step was to
% copy the input settings from another network before rebuilding, e.g.
%   networkStruct = struct(net);
%   layerStruct.inputs{1} = networkStruct.inputs{1};
% Until then this is a plain struct -> network round trip.
layerStruct = struct(layer);
layer = network(layerStruct);
end