-
Notifications
You must be signed in to change notification settings - Fork 382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
several improvements #80
base: master
Are you sure you want to change the base?
Changes from 1 commit
4c62609
1b2a009
553a9a5
32cb043
a3863a7
cc68429
e0c83d0
c169f27
4fbe9c8
97882a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,6 +152,32 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data) | |
return fann_get_MSE(ann); | ||
} | ||
|
||
float fann_train_epoch_irpropm_lw(struct fann *ann, struct fann_train_data *data, fann_type* label_weight) | ||
{ | ||
unsigned int i; | ||
|
||
if(ann->prev_train_slopes == NULL) | ||
{ | ||
fann_clear_train_arrays(ann); | ||
} | ||
|
||
fann_reset_MSE(ann); | ||
|
||
for(i = 0; i < data->num_data; i++) | ||
{ | ||
fann_run(ann, data->input[i]); | ||
fann_compute_MSE_lw(ann, data->output[i], label_weight[i]); | ||
fann_backpropagate_MSE(ann); | ||
fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1); | ||
} | ||
|
||
fann_update_weights_irpropm(ann, 0, ann->total_connections); | ||
|
||
return fann_get_MSE(ann); | ||
} | ||
|
||
|
||
|
||
/* | ||
* Internal train function | ||
*/ | ||
|
@@ -244,6 +270,24 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_trai | |
return 0; | ||
} | ||
|
||
FANN_EXTERNAL float FANN_API fann_train_epoch_lw(struct fann *ann, struct fann_train_data *data, | ||
fann_type* label_weight) | ||
{ | ||
if(fann_check_input_output_sizes(ann, data) == -1) | ||
return 0; | ||
|
||
switch (ann->training_algorithm) | ||
{ | ||
case FANN_TRAIN_RPROP: | ||
return fann_train_epoch_irpropm_lw(ann, data, label_weight); | ||
default: | ||
printf("FANN : fann_train_epoch_lw not implemented with others algo\n"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes more sense to call: fann_error((struct fann_error *) ann, FANN_E_CANT_USE_TRAIN_ALG); here and fall through to the return statement. |
||
} | ||
return 0; | ||
} | ||
|
||
|
||
|
||
FANN_EXTERNAL void FANN_API fann_train_on_data(struct fann *ann, struct fann_train_data *data, | ||
unsigned int max_epochs, | ||
unsigned int epochs_between_reports, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -212,6 +212,7 @@ FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *fil | |
This function appears in FANN >= 1.2.0. | ||
*/ | ||
FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data); | ||
FANN_EXTERNAL float FANN_API fann_train_epoch_lw(struct fann *ann, struct fann_train_data *data, fann_type* label_weight); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function and the others There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We likely also want to add |
||
#endif /* NOT FIXEDFANN */ | ||
|
||
/* Function: fann_test_data | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've got no experience with the FIXEDFANN mode, do we have to divide
label_weight
by(float) ann->multiplier
here?