Skip to content

Commit

Permalink
ndarray3_fft: implement forward and reverse real FFT
Browse files Browse the repository at this point in the history
This passes `ndarray3`/`ndarray3_complex` to Kiss-FFT.
  • Loading branch information
zmughal committed Nov 2, 2015
1 parent e4968cf commit eb18136
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 43 deletions.
49 changes: 45 additions & 4 deletions lib/ndarray/ndarray3_fft.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
#include "ndarray/ndarray3_fft.h"

ndarray3_complex* ndarray3_fftn( ndarray3* n ) {
WARN_UNIMPLEMENTED;
#include <kiss_fft.h>
#include <kiss_fftndr.h>

ndarray3_complex* ndarray3_fftn_r2c( ndarray3* n ) {
ndarray3_complex* freq_data =
ndarray3_complex_new(n->sz[0], n->sz[1], n->sz[2]);

int* dims;
NEW_COUNT( dims, int, PIXEL_NDIMS );
dims[0] = n->sz[0];
dims[1] = n->sz[1];
dims[2] = n->sz[2];
kiss_fftndr_cfg fft =
kiss_fftndr_alloc( dims, PIXEL_NDIMS, false , NULL, NULL);

kiss_fftndr(fft, (kiss_fft_scalar*)(n->p), freq_data->p );

free(dims);
free(fft);

return freq_data;
}

ndarray3* ndarray3_ifftn_symmetric( ndarray3* n ) {
WARN_UNIMPLEMENTED;
ndarray3* ndarray3_ifftn_c2r( ndarray3_complex* n ) {
ndarray3* spatial_data =
ndarray3_new(n->sz[0], n->sz[1], n->sz[2]);

int* dims;
NEW_COUNT( dims, int, PIXEL_NDIMS );
dims[0] = n->sz[0];
dims[1] = n->sz[1];
dims[2] = n->sz[2];

kiss_fftndr_cfg ifft =
kiss_fftndr_alloc( dims, PIXEL_NDIMS, true , NULL, NULL);

kiss_fftndri(ifft, n->p, spatial_data->p );

size_t nelems = ndarray3_complex_elems( n );
for( int i = 0; i < nelems; i++ ) {
spatial_data->p[i] /= nelems; /* scaling */
}

free(dims);
free(ifft);

return spatial_data;
}
8 changes: 3 additions & 5 deletions lib/ndarray/ndarray3_fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@

/* local headers */
#include "ndarray/ndarray3.h"

/* structs, enums */
typedef void ndarray3_complex; /* TODO remove this later */
#include "ndarray/ndarray3_complex.h"

#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */

/* Function prototypes */
extern ndarray3_complex* ndarray3_fftn( ndarray3* n );
extern ndarray3* ndarray3_ifftn_symmetric( ndarray3* n );
extern ndarray3_complex* ndarray3_fftn_r2c( ndarray3* n );
extern ndarray3* ndarray3_ifftn_c2r( ndarray3_complex* n );


#ifdef __cplusplus
Expand Down
42 changes: 8 additions & 34 deletions lib/t/ndarray/ndarray3_fft.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
#include <tap/basic.h>
#include <tap/float.h>

#include <kiss_fft.h>
#include <kiss_fftndr.h>

#include "util/util.h"
#include "ndarray/ndarray3.h"
#include "ndarray/ndarray3_fft.h"

int main(void) {
plan(1);
Expand All @@ -24,32 +21,13 @@ int main(void) {
n->p[idx] = idx;
}

kiss_fft_cpx* freq_data;
kiss_fft_scalar* inv_spatial_data;

NEW_COUNT( freq_data, kiss_fft_cpx, n_nelems );
NEW_COUNT( inv_spatial_data, kiss_fft_scalar, n_nelems );

int dims[3];
dims[0] = n->sz[0];
dims[1] = n->sz[1];
dims[2] = n->sz[2];
kiss_fftndr_cfg fft =
kiss_fftndr_alloc( dims, PIXEL_NDIMS, false , NULL, NULL);
kiss_fftndr_cfg ifft =
kiss_fftndr_alloc( dims, PIXEL_NDIMS, true , NULL, NULL);

kiss_fftndr(fft, (kiss_fft_scalar*)(n->p), freq_data );
kiss_fftndri(ifft, freq_data, inv_spatial_data );

for( int i = 0; i < n_nelems; i++ ) {
inv_spatial_data[i] /= n_nelems; /* scaling */
}
ndarray3_complex* freq_data = ndarray3_fftn_r2c( n );
ndarray3* inv_spatial_data = ndarray3_ifftn_c2r( freq_data );

/*[>DEBUG<]{
printf("Input: \t\tOutput:\n");
for( int i = 0; i < n_nelems; i++ ) {
printf("%f\t\t%f\n",n->p[i], inv_spatial_data[i]);
printf("%f\t\t%f\n",n->p[i], inv_spatial_data->p[i]);
}
}*/

Expand All @@ -61,10 +39,10 @@ int main(void) {
for( int i = 0; i < n_nelems && all_same; i++ ) {
/*[>DEBUG<]printf("[%d] : %g\n", i, fabs(array[i] - buf[i]));*/

all_same = all_same && fabs(n->p[i] - inv_spatial_data[i]) < eps ;
all_same = all_same && fabs(n->p[i] - inv_spatial_data->p[i]) < eps ;

#ifdef ORION_DEBUG
float64 diff = n->p[i] - inv_spatial_data[i];
float64 diff = n->p[i] - inv_spatial_data->p[i];
square_error += SQUARED(diff);
if( max_diff < fabs(diff) ) {
max_diff = fabs(diff);
Expand All @@ -79,14 +57,10 @@ int main(void) {
ok( all_same, "ifft(fft(x)) == x");

ndarray3_free(n);
free(freq_data);
free(inv_spatial_data);

free(fft);
free(ifft);
ndarray3_complex_free(freq_data);
ndarray3_free(inv_spatial_data);

kiss_fft_cleanup();


return EXIT_SUCCESS;
}

0 comments on commit eb18136

Please sign in to comment.