Skip to content

Commit

Permalink
modify function signatures to allow passing back errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jjYBdx4IL authored and jjYBdx4IL committed Jun 16, 2016
1 parent d71d547 commit 62c9de2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/fann_cascade.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ FANN_EXTERNAL void FANN_API fann_cascadetrain_on_file(struct fann *ann, const ch
unsigned int neurons_between_reports,
float desired_error)
{
struct fann_train_data *data = fann_read_train_from_file(filename);
struct fann_train_data *data = fann_read_train_from_file((struct fann_error *) ann, filename);

if(data == NULL)
{
Expand Down
43 changes: 22 additions & 21 deletions src/fann_train_data.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@
/*
* Reads training data from a file.
*/
FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *configuration_file)
FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(struct fann_error *err, const char *configuration_file)
{
struct fann_train_data *data;
FILE *file = fopen(configuration_file, "r");

if(!file)
{
fann_error(NULL, FANN_E_CANT_OPEN_CONFIG_R, configuration_file);
fann_error(err, FANN_E_CANT_OPEN_CONFIG_R, configuration_file);
return NULL;
}

data = fann_read_train_from_fd(file, configuration_file);
data = fann_read_train_from_fd(err, file, configuration_file);
fclose(file);
return data;
}
Expand Down Expand Up @@ -302,7 +302,7 @@ FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *fil
unsigned int epochs_between_reports,
float desired_error)
{
struct fann_train_data *data = fann_read_train_from_file(filename);
struct fann_train_data *data = fann_read_train_from_file((struct fann_error *)ann, filename);

if(data == NULL)
{
Expand Down Expand Up @@ -806,7 +806,7 @@ int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const
/*
* Creates an empty set of training data
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output)
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(struct fann_error *err, unsigned int num_data, unsigned int num_input, unsigned int num_output)
{
fann_type *data_input, *data_output;
unsigned int i;
Expand All @@ -815,7 +815,7 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int n

if(data == NULL)
{
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
fann_error(err, FANN_E_CANT_ALLOCATE_MEM);
return NULL;
}

Expand All @@ -827,31 +827,31 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int n
data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
if(data->input == NULL)
{
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
fann_error(err, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy_train(data);
return NULL;
}

data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
if(data->output == NULL)
{
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
fann_error(err, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy_train(data);
return NULL;
}

data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
if(data_input == NULL)
{
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
fann_error(err, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy_train(data);
return NULL;
}

data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
if(data_output == NULL)
{
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
fann_error(err, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy_train(data);
return NULL;
}
Expand All @@ -866,11 +866,11 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int n
return data;
}

FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output)
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(struct fann_error *err, unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output)
{
unsigned int i;
struct fann_train_data *data;
data = fann_create_train(num_data, num_input, num_output);
data = fann_create_train(err, num_data, num_input, num_output);

if(data == NULL)
return NULL;
Expand All @@ -884,11 +884,11 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(
return data;
}

FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output)
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(struct fann_error *err, unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output)
{
unsigned int i;
struct fann_train_data *data;
data = fann_create_train(num_data, num_input, num_output);
data = fann_create_train(err, num_data, num_input, num_output);

if(data == NULL)
return NULL;
Expand All @@ -906,7 +906,8 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned
/*
* Creates training data from a callback function.
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(struct fann_error *err,
unsigned int num_data,
unsigned int num_input,
unsigned int num_output,
void (FANN_API *user_function)( unsigned int,
Expand All @@ -916,7 +917,7 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(
fann_type * ))
{
unsigned int i;
struct fann_train_data *data = fann_create_train(num_data, num_input, num_output);
struct fann_train_data *data = fann_create_train(err, num_data, num_input, num_output);
if(data == NULL)
{
return NULL;
Expand Down Expand Up @@ -948,20 +949,20 @@ FANN_EXTERNAL fann_type * FANN_API fann_get_train_output(struct fann_train_data
/*
* INTERNAL FUNCTION Reads training data from a file descriptor.
*/
struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
struct fann_train_data *fann_read_train_from_fd(struct fann_error *err, FILE * file, const char *filename)
{
unsigned int num_input, num_output, num_data, i, j;
unsigned int line = 1;
struct fann_train_data *data;

if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
{
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
fann_error(err, FANN_E_CANT_READ_TD, filename, line);
return NULL;
}
line++;

data = fann_create_train(num_data, num_input, num_output);
data = fann_create_train(err, num_data, num_input, num_output);
if(data == NULL)
{
return NULL;
Expand All @@ -973,7 +974,7 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
{
if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
{
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
fann_error(err, FANN_E_CANT_READ_TD, filename, line);
fann_destroy_train(data);
return NULL;
}
Expand All @@ -984,7 +985,7 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
{
if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
{
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
fann_error(err, FANN_E_CANT_READ_TD, filename, line);
fann_destroy_train(data);
return NULL;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
/* #undef PACKAGE */

/* Version number of package */
#define VERSION "2.2.0"
/* #undef VERSION */

/* Define for the x86_64 CPU famyly */
/* #undef X86_64 */
2 changes: 1 addition & 1 deletion src/include/fann_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void fann_error(struct fann_error *errdat, const enum fann_errno_enum errno_f, .
void fann_init_error_data(struct fann_error *errdat);

struct fann *fann_create_from_fd(FILE * conf, const char *configuration_file);
struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename);
struct fann_train_data *fann_read_train_from_fd(struct fann_error * err, FILE * file, const char *filename);

void fann_compute_MSE(struct fann *ann, fann_type * desired_output);
void fann_update_output_weights(struct fann *ann);
Expand Down
11 changes: 6 additions & 5 deletions src/include/fann_train.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
This function appears in FANN >= 1.0.0
*/
FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *filename);
FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(struct fann_error * err, const char *filename);


/* Function: fann_create_train
Expand All @@ -261,7 +261,7 @@ FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const c
This function appears in FANN >= 2.2.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output);
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(struct fann_error *err, unsigned int num_data, unsigned int num_input, unsigned int num_output);

/* Function: fann_create_train_pointer_array
Creates an training data struct and fills it with data from provided arrays of pointer.
Expand All @@ -276,7 +276,7 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int n
This function appears in FANN >= 2.3.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output);
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(struct fann_error *err, unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output);

/* Function: fann_create_train_array
Creates an training data struct and fills it with data from provided arrays, where the arrays must have the dimensions:
Expand All @@ -293,7 +293,7 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(
This function appears in FANN >= 2.3.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output);
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(struct fann_error *err, unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output);

/* Function: fann_create_train_from_callback
Creates the training data struct from a user supplied function.
Expand Down Expand Up @@ -325,7 +325,8 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned
This function appears in FANN >= 2.1.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(struct fann_error *err,
unsigned int num_data,
unsigned int num_input,
unsigned int num_output,
void (FANN_API *user_function)( unsigned int,
Expand Down
8 changes: 4 additions & 4 deletions src/include/fann_training_data_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ namespace FANN {
*/
bool read_train_from_file(const std::string &filename) {
destroy_train();
train_data = fann_read_train_from_file(filename.c_str());
train_data = fann_read_train_from_file(NULL, filename.c_str());
return (train_data != NULL);
}

Expand Down Expand Up @@ -333,7 +333,7 @@ namespace FANN {
void set_train_data(unsigned int num_data,
unsigned int num_input, fann_type **input,
unsigned int num_output, fann_type **output) {
set_train_data(fann_create_train_pointer_array(num_data, num_input, input, num_output, output));
set_train_data(fann_create_train_pointer_array(NULL, num_data, num_input, input, num_output, output));
}

/* Method: set_train_data
Expand All @@ -357,7 +357,7 @@ namespace FANN {
void set_train_data(unsigned int num_data,
unsigned int num_input, fann_type *input,
unsigned int num_output, fann_type *output) {
set_train_data(fann_create_train_array(num_data, num_input, input, num_output, output));
set_train_data(fann_create_train_array(NULL, num_data, num_input, input, num_output, output));
}

private:
Expand Down Expand Up @@ -406,7 +406,7 @@ namespace FANN {
fann_type *,
fann_type *)) {
destroy_train();
train_data = fann_create_train_from_callback(num_data, num_input, num_output, user_function);
train_data = fann_create_train_from_callback(NULL, num_data, num_input, num_output, user_function);
}

#ifndef FIXEDFANN
Expand Down

0 comments on commit 62c9de2

Please sign in to comment.