several improvements #80

Open · wants to merge 10 commits into base: master
3 changes: 3 additions & 0 deletions src/fann.c
@@ -900,6 +900,7 @@ FANN_EXTERNAL struct fann* FANN_API fann_copy(struct fann* orig)

copy->learning_rate = orig->learning_rate;
copy->learning_momentum = orig->learning_momentum;
copy->learning_l2_norm = orig->learning_l2_norm;
copy->connection_rate = orig->connection_rate;
copy->network_type = orig->network_type;
copy->num_MSE = orig->num_MSE;
@@ -1288,6 +1289,7 @@ FANN_EXTERNAL void FANN_API fann_print_parameters(struct fann *ann)
printf("Bit fail limit :%8.3f\n", ann->bit_fail_limit);
printf("Learning rate :%8.3f\n", ann->learning_rate);
printf("Learning momentum :%8.3f\n", ann->learning_momentum);
printf("Learning l2 norm :%8.3f\n", ann->learning_l2_norm);
printf("Quickprop decay :%11.6f\n", ann->quickprop_decay);
printf("Quickprop mu :%8.3f\n", ann->quickprop_mu);
printf("RPROP increase factor :%8.3f\n", ann->rprop_increase_factor);
@@ -1647,6 +1649,7 @@ struct fann *fann_allocate_structure(unsigned int num_layers)
ann->errstr = NULL;
ann->learning_rate = 0.7f;
ann->learning_momentum = 0.0;
ann->learning_l2_norm = 0.0;
ann->total_neurons = 0;
ann->total_connections = 0;
ann->num_input = 0;
2 changes: 2 additions & 0 deletions src/fann_io.c
@@ -181,6 +181,7 @@ int fann_save_internal_fd(struct fann *ann, FILE * conf, const char *configurati
fprintf(conf, "network_type=%u\n", ann->network_type);

fprintf(conf, "learning_momentum=%f\n", ann->learning_momentum);
fprintf(conf, "learning_l2_norm=%f\n", ann->learning_l2_norm);

Review comment: Are there any versioning concerns here, or are we free to break reading of ANN files that were written with an older library version? (A possible backward-compatible approach is sketched after this file's diff.)

fprintf(conf, "training_algorithm=%u\n", ann->training_algorithm);
fprintf(conf, "train_error_function=%u\n", ann->train_error_function);
fprintf(conf, "train_stop_function=%u\n", ann->train_stop_function);
@@ -443,6 +444,7 @@ struct fann *fann_create_from_fd(FILE * conf, const char *configuration_file)
fann_scanf("%u", "network_type", &tmpVal);
ann->network_type = (enum fann_nettype_enum)tmpVal;
fann_scanf("%f", "learning_momentum", &ann->learning_momentum);
fann_scanf("%f", "learning_l2_norm", &ann->learning_l2_norm);

Review comment: Same as for fann_save_internal_fd() further up.

fann_scanf("%u", "training_algorithm", &tmpVal);
ann->training_algorithm = (enum fann_train_enum)tmpVal;
fann_scanf("%u", "train_error_function", &tmpVal);
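
One way to address the versioning question raised above, offered only as a suggestion and not as part of this PR: treat the new key as optional when loading, defaulting it to 0 when absent, so .net files written by older library versions still load. The snippet assumes the conf and ann variables visible inside fann_create_from_fd() and uses plain fscanf plus a rewind instead of the fann_scanf macro used above.

long pos = ftell(conf);
if(fscanf(conf, "learning_l2_norm=%f\n", &ann->learning_l2_norm) != 1)
{
	/* Key not present: the file was written by an older FANN version.
	 * Fall back to no regularization and rewind so the following
	 * fann_scanf calls still match their keys. */
	ann->learning_l2_norm = 0.0f;
	fseek(conf, pos, SEEK_SET);
}
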
4 changes: 3 additions & 1 deletion src/fann_train.c
@@ -978,7 +978,7 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight, un
for(; i != past_end; i++)
{
prev_step = fann_max(prev_steps[i], (fann_type) 0.0001); /* prev_step may not be zero because then the training will stop */
-   slope = train_slopes[i];
+   slope = train_slopes[i] - ann->learning_l2_norm * weights[i];

Review comment: This only adds the L2 regularization to the irpropm algorithm; the same term should be added to the other fann_update_weights*() functions. (A sketch of what that could look like follows this file's diff.)

prev_slope = prev_train_slopes[i];

same_sign = prev_slope * slope;
@@ -1294,3 +1294,5 @@ FANN_GET_SET(float, sarprop_temperature)
FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
FANN_GET_SET(fann_type, bit_fail_limit)
FANN_GET_SET(float, learning_momentum)
FANN_GET_SET(float, learning_l2_norm)
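
To illustrate the comment about the other update functions, here is a minimal sketch of how the same term could be carried into batch training. The body below paraphrases fann_update_weights_batch() from the existing source and is an assumption for illustration, not part of this PR.

void fann_update_weights_batch(struct fann *ann, unsigned int num_data,
                               unsigned int first_weight, unsigned int past_end)
{
	fann_type *train_slopes = ann->train_slopes;
	fann_type *weights = ann->weights;
	const float epsilon = ann->learning_rate / num_data;
	unsigned int i = first_weight;

	for(; i != past_end; i++)
	{
		/* Subtract the L2 (weight decay) term from the accumulated slope,
		 * mirroring the irpropm change above. */
		weights[i] += (train_slopes[i] - ann->learning_l2_norm * weights[i]) * epsilon;
		train_slopes[i] = 0.0;
	}
}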

17 changes: 17 additions & 0 deletions src/include/fann_cpp.h
@@ -2686,6 +2686,23 @@ namespace FANN
}
}

float get_learning_l2_norm()
{
float learning_l2_norm = 0.0f;
if (ann != NULL)
{
learning_l2_norm = fann_get_learning_l2_norm(ann);
}
return learning_l2_norm;
}

void set_learning_l2_norm(float learning_l2_norm)
{
if (ann != NULL)
{
fann_set_learning_l2_norm(ann, learning_l2_norm);
}
}
/* Method: get_train_stop_function

Returns the stop function used during training.
1 change: 1 addition & 0 deletions src/include/fann_data.h
@@ -491,6 +491,7 @@ struct fann

/* The learning momentum used for backpropagation algorithm. */
float learning_momentum;
float learning_l2_norm;

Review comment: Is learning_l2_norm the best name for this parameter? It is actually not the L2 norm itself, but rather the weight of the L2 norm term relative to the error function. (See the note after this file's diff.)


/* the connection rate of the network
* between 0 and 1, 1 meaning fully connected
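
A note on the naming question above: as used in fann_update_weights_irpropm(), the parameter is the coefficient lambda of an L2 penalty added to the error, E_reg = E + (lambda / 2) * sum_i(w_i^2). Its gradient contributes lambda * w_i per weight, which is why the diff subtracts learning_l2_norm * weights[i] from the slope. A name that describes this coefficient role (an "l2 decay" or "regularization strength" style name) might fit better; this reading of the diff is offered only as a suggestion.
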
3 changes: 2 additions & 1 deletion src/include/fann_train.h
@@ -769,7 +769,8 @@ FANN_EXTERNAL float FANN_API fann_get_learning_momentum(struct fann *ann);
This function appears in FANN >= 2.0.0.
*/
FANN_EXTERNAL void FANN_API fann_set_learning_momentum(struct fann *ann, float learning_momentum);

FANN_EXTERNAL float FANN_API fann_get_learning_l2_norm(struct fann *ann);
FANN_EXTERNAL void FANN_API fann_set_learning_l2_norm(struct fann *ann, float learning_l2_norm);

Review comment: These functions are external and should be documented. (A possible documentation sketch follows this file's diff.)


/* Function: fann_get_activation_function

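A possible shape for the requested documentation, following the header's existing Natural Docs style; the wording below is illustrative only and not part of this PR.

/* Function: fann_get_learning_l2_norm

   Returns the L2 regularization coefficient used during training.
   The default is 0.0, meaning no regularization.

   See also:
       <fann_set_learning_l2_norm>
 */
FANN_EXTERNAL float FANN_API fann_get_learning_l2_norm(struct fann *ann);

/* Function: fann_set_learning_l2_norm

   Sets the coefficient of the L2 weight-decay term that is subtracted from
   each weight's slope during training (currently applied in
   fann_update_weights_irpropm only).

   See also:
       <fann_get_learning_l2_norm>
 */
FANN_EXTERNAL void FANN_API fann_set_learning_l2_norm(struct fann *ann, float learning_l2_norm);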