
several improvements #80

Open · wants to merge 10 commits into master
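The changes shown here add label-weighted ("_lw") variants of FANN's training entry points: fann_update_MSE_lw(), fann_compute_MSE_lw(), fann_train_epoch_irpropm_lw(), and the public fann_train_epoch_lw(), which let each training sample contribute to the MSE and to the gradient with its own weight.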
109 changes: 109 additions & 0 deletions src/fann_train.c
@@ -156,6 +156,55 @@ fann_type fann_update_MSE(struct fann *ann, struct fann_neuron* neuron, fann_type neuron_diff)
return neuron_diff;
}

fann_type fann_update_MSE_lw(struct fann *ann, struct fann_neuron* neuron, fann_type neuron_diff, fann_type label_weight)
{
    float neuron_diff2;

    switch (neuron->activation_function)
    {
        case FANN_LINEAR_PIECE_SYMMETRIC:
        case FANN_THRESHOLD_SYMMETRIC:
        case FANN_SIGMOID_SYMMETRIC:
        case FANN_SIGMOID_SYMMETRIC_STEPWISE:
        case FANN_ELLIOT_SYMMETRIC:
        case FANN_GAUSSIAN_SYMMETRIC:
        case FANN_SIN_SYMMETRIC:
        case FANN_COS_SYMMETRIC:
            neuron_diff /= (fann_type)2.0;
            break;
        case FANN_THRESHOLD:
        case FANN_LINEAR:
        case FANN_SIGMOID:
        case FANN_SIGMOID_STEPWISE:
        case FANN_GAUSSIAN:
        case FANN_GAUSSIAN_STEPWISE:
        case FANN_ELLIOT:
        case FANN_LINEAR_PIECE:
        case FANN_SIN:
        case FANN_COS:
            break;
    }

#ifdef FIXEDFANN
    neuron_diff2 =
        (neuron_diff / (float) ann->multiplier) * (neuron_diff / (float) ann->multiplier);
#else
    neuron_diff2 = (float) (neuron_diff * neuron_diff);
#endif

    ann->MSE_value += neuron_diff2 * label_weight;

Review comment: I've got no experience with the FIXEDFANN mode; do we have to divide label_weight by (float) ann->multiplier here?
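Note (a reading of the fixed-point convention, untested): in FIXEDFANN builds a real value x is stored as x * ann->multiplier, which is why each neuron_diff factor above is divided by (float) ann->multiplier before squaring. If callers pass label_weight as a plain, unscaled weight, the accumulation is correct as written; only if label_weight itself arrived as a scaled fixed-point fann_type would one extra division by (float) ann->multiplier be needed.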


    /*printf("neuron_diff %f = (%f - %f)[/2], neuron_diff2=%f, sum=%f, MSE_value=%f, num_MSE=%d\n", neuron_diff, *desired_output, neuron_value, neuron_diff2, last_layer_begin->sum, ann->MSE_value, ann->num_MSE); */
    if(fann_abs(neuron_diff) >= ann->bit_fail_limit)
    {
        ann->num_bit_fail++;
    }

    return neuron_diff;
}



/* Tests the network.
*/
FANN_EXTERNAL fann_type *FANN_API fann_test(struct fann *ann, fann_type * input,
@@ -284,6 +333,66 @@ void fann_compute_MSE(struct fann *ann, fann_type * desired_output)
}
}

void fann_compute_MSE_lw(struct fann *ann, fann_type * desired_output, fann_type label_weight)
{
    fann_type neuron_value, neuron_diff, *error_it = 0, *error_begin = 0;
    struct fann_neuron *last_layer_begin = (ann->last_layer - 1)->first_neuron;
    const struct fann_neuron *last_layer_end = last_layer_begin + ann->num_output;
    const struct fann_neuron *first_neuron = ann->first_layer->first_neuron;

    /* if no room allocated for the error variables, allocate it now */
    if(ann->train_errors == NULL)
    {
        ann->train_errors = (fann_type *) calloc(ann->total_neurons, sizeof(fann_type));
        if(ann->train_errors == NULL)
        {
            fann_error((struct fann_error *) ann, FANN_E_CANT_ALLOCATE_MEM);
            return;
        }
    }
    else
    {
        /* clear the error variables */
        memset(ann->train_errors, 0, (ann->total_neurons) * sizeof(fann_type));
    }
    error_begin = ann->train_errors;

#ifdef DEBUGTRAIN
    printf("\ncalculate errors\n");
#endif
    /* calculate the error and place it in the output layer */
    error_it = error_begin + (last_layer_begin - first_neuron);

    for(; last_layer_begin != last_layer_end; last_layer_begin++)
    {
        neuron_value = last_layer_begin->value;
        neuron_diff = *desired_output - neuron_value;

        neuron_diff = fann_update_MSE_lw(ann, last_layer_begin, neuron_diff, label_weight);

        if(ann->train_error_function)
        {   /* TODO make switch when more functions */
            if(neuron_diff < -.9999999)
                neuron_diff = -17.0;
            else if(neuron_diff > .9999999)
                neuron_diff = 17.0;
            else
                neuron_diff = (fann_type) log((1.0 + neuron_diff) / (1.0 - neuron_diff));
        }

        *error_it = fann_activation_derived(last_layer_begin->activation_function,
                                            last_layer_begin->activation_steepness, neuron_value,
                                            last_layer_begin->sum) * neuron_diff * label_weight;

        desired_output++;
        error_it++;

        ann->num_MSE++;
    }
}
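For reference, the extra factors implement a per-sample weighted squared error: with weight w and output error diff = d - y, the sample's loss contribution is w * diff^2 (accumulated into MSE_value by fann_update_MSE_lw), and its derivative with respect to the output is proportional to w * diff, which is why the derived error written into train_errors carries the label_weight factor.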



/* INTERNAL FUNCTION
Propagate the error backwards from the output layer.

44 changes: 44 additions & 0 deletions src/fann_train_data.c
@@ -152,6 +152,32 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
return fann_get_MSE(ann);
}

float fann_train_epoch_irpropm_lw(struct fann *ann, struct fann_train_data *data, fann_type* label_weight)
{
    unsigned int i;

    if(ann->prev_train_slopes == NULL)
    {
        fann_clear_train_arrays(ann);
    }

    fann_reset_MSE(ann);

    for(i = 0; i < data->num_data; i++)
    {
        fann_run(ann, data->input[i]);
        fann_compute_MSE_lw(ann, data->output[i], label_weight[i]);
        fann_backpropagate_MSE(ann);
        fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
    }

    fann_update_weights_irpropm(ann, 0, ann->total_connections);

    return fann_get_MSE(ann);
}



/*
* Internal train function
*/
@@ -244,6 +270,24 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
return 0;
}

FANN_EXTERNAL float FANN_API fann_train_epoch_lw(struct fann *ann, struct fann_train_data *data,
                                                 fann_type* label_weight)
{
    if(fann_check_input_output_sizes(ann, data) == -1)
        return 0;

    switch (ann->training_algorithm)
    {
        case FANN_TRAIN_RPROP:
            return fann_train_epoch_irpropm_lw(ann, data, label_weight);
        default:
            printf("FANN: fann_train_epoch_lw is not implemented for this training algorithm\n");

Review comment: I think it makes more sense to call fann_error((struct fann_error *) ann, FANN_E_CANT_USE_TRAIN_ALG); here and fall through to the return statement.
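A sketch of what that suggestion would look like (FANN_E_CANT_USE_TRAIN_ALG is an existing FANN error code; only the default case changes, and execution then falls through to the return below):

        default:
            fann_error((struct fann_error *) ann, FANN_E_CANT_USE_TRAIN_ALG);
            break;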

    }
    return 0;
}
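A minimal usage sketch (not part of the patch; the topology, the file name "train.data", the epoch count, and the weight values are placeholders, and error handling is omitted): train with per-sample weights, here emphasizing the first sample threefold.

#include <stdio.h>
#include <stdlib.h>
#include "fann.h"

int main(void)
{
    struct fann *ann = fann_create_standard(3, 2, 8, 1);
    struct fann_train_data *data = fann_read_train_from_file("train.data");
    fann_type *label_weight = (fann_type *) malloc(data->num_data * sizeof(fann_type));
    unsigned int i, epoch;

    /* fann_train_epoch_lw only supports RPROP */
    fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);

    /* weight every sample equally, except the first one three times as much */
    for(i = 0; i < data->num_data; i++)
        label_weight[i] = (fann_type) 1.0;
    label_weight[0] = (fann_type) 3.0;

    for(epoch = 0; epoch < 100; epoch++)
        printf("MSE: %f\n", fann_train_epoch_lw(ann, data, label_weight));

    free(label_weight);
    fann_destroy_train(data);
    fann_destroy(ann);
    return 0;
}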



FANN_EXTERNAL void FANN_API fann_train_on_data(struct fann *ann, struct fann_train_data *data,
unsigned int max_epochs,
unsigned int epochs_between_reports,
1 change: 1 addition & 0 deletions src/include/fann_train.h
@@ -212,6 +212,7 @@ FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *filename,
This function appears in FANN >= 1.2.0.
*/
FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data);
FANN_EXTERNAL float FANN_API fann_train_epoch_lw(struct fann *ann, struct fann_train_data *data, fann_type* label_weight);

Review comment: This function and the other fann_*_lw() functions should get documentation in the same style as the rest.
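A sketch of such a block, modeled on the fann_train_epoch entry above (the wording and the "See also" target are assumptions):

    /* Function: fann_train_epoch_lw

       Train one epoch with a set of training data, scaling each
       sample's error contribution by the corresponding entry of
       label_weight, an array of data->num_data weights. Only
       implemented for FANN_TRAIN_RPROP.

       See also:
            <fann_train_epoch>
     */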

Review comment: We likely also want to add train_epoch_lw() in some way to the FANN::neural_net class in fann_cpp.h.

#endif /* NOT FIXEDFANN */

/* Function: fann_test_data