Skip to content

Commit

Permalink
Fixed FFTW3&MKL wrapper, working on concrete-fft
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jul 4, 2024
1 parent 2fa21d9 commit 29c2b44
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "thirdparties/hexl/hexl"]
path = thirdparties/hexl/hexl
url = https://github.com/intel/hexl.git
[submodule "thirdparties/concrete-fft"]
path = thirdparties/concrete-fft
url = https://github.com/virtualsecureplatform/concrete-fft.git
12 changes: 9 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ if(NOT USE_TERNARY)
endif()
endif()

if(USE_MKL)
set(USE_FFTW3 ON)
endif()

if(USE_AVX512)
string(APPEND CMAKE_CXX_FLAGS " -mprefer-vector-width=512")
Expand All @@ -110,6 +107,15 @@ if(USE_FFTW3)
PARENT_SCOPE)
add_compile_definitions(USE_FFTW3)
add_subdirectory(thirdparties/fftw)
elseif(USE_MKL)
set(TFHEpp_DEFINITIONS
"${TFHEpp_DEFINITIONS};USE_MKL"
PARENT_SCOPE)
add_compile_definitions(USE_MKL)
add_compile_definitions(USE_INTERLEAVED_FORMAT)
find_package(MKL CONFIG REQUIRED PATHS $ENV{MKLROOT})
include_directories(${MKLROOT}/include)
add_subdirectory(thirdparties/mkl)
elseif(USE_SPQLIOX_AARCH64)
set(TFHEpp_DEFINITIONS
"${TFHEpp_DEFINITIONS};USE_SPQLIOX_AARCH64"
Expand Down
22 changes: 22 additions & 0 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "INTorus.hpp"
#ifdef USE_FFTW3
#include <fft_processor_fftw.h>
#elif USE_MKL
#include <fft_processor_mkl.hpp>
#elif USE_SPQLIOX_AARCH64
#include <fft_processor_spqliox_aarch64.h>
#else
Expand All @@ -14,6 +16,10 @@
#include "hexl/hexl.hpp"
#endif

#ifdef USE_INTERLEAVED_FORMAT
#include <complex>
#endif

#include "cuhe++.hpp"
#include "params.hpp"
#include "utils.hpp"
Expand Down Expand Up @@ -141,12 +147,20 @@ template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] = tmp.real();
res[2*i+1] = tmp.imag();
}
#else
for (int i = 0; i < N / 2; i++) {
double aimbim = a[i + N / 2] * b[i + N / 2];
double arebim = a[i] * b[i + N / 2];
res[i] = std::fma(a[i], b[i], -aimbim);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
}
#endif
}

// Be careful about memory accesss (We assume b has relatively high memory
Expand All @@ -155,6 +169,13 @@ template <uint32_t N>
inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] += tmp.real();
res[2*i+1] += tmp.imag();
}
#else
for (int i = 0; i < N / 2; i++) {
res[i] = std::fma(a[i], b[i], res[i]);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
Expand All @@ -169,6 +190,7 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
// res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
// res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
// }
#endif
}

template <class P>
Expand Down
24 changes: 11 additions & 13 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,15 @@ if(USE_RANDEN)
target_link_libraries(tfhe++ INTERFACE randen)
endif()

if(USE_MKL)
find_package(MKL CONFIG REQUIRED PATHS $ENV{MKLROOT})
target_link_libraries(tfhe++ PUBLIC MKL::MKL)
endif()

if(USE_FFTW3)
target_link_libraries(tfhe++ INTERFACE fftwproc)
if(USE_MKL)
target_include_directories(tfhe++ PUBLIC ${MKLROOT}/include/fftw)
target_link_libraries(
tfhe++
INTERFACE "-Wl,--start-group"
$ENV{MKLROOT}/interfaces/fftw3xc/libfftw3xc_gnu.a
$ENV{MKLROOT}/lib/intel64/libmkl_cdft_core.a
$ENV{MKLROOT}/lib/intel64/libmkl_intel_lp64.a
$ENV{MKLROOT}/lib/intel64/libmkl_sequential.a
$ENV{MKLROOT}/lib/intel64/libmkl_core.a
$ENV{MKLROOT}/lib/intel64/libmkl_blacs_intelmpi_lp64.a
"-Wl,--end-group"
${CMAKE_DL_LIBS}
pthread)
else()
target_link_libraries(tfhe++ INTERFACE fftw3)
endif()
Expand All @@ -47,5 +40,10 @@ elseif(USE_SPQLIOX_AARCH64)
elseif(USE_HEXL)
target_link_libraries(tfhe++ INTERFACE spqlios HEXL::hexl)
else()
target_link_libraries(tfhe++ INTERFACE spqlios)
if(USE_MKL)
target_link_libraries(tfhe++ INTERFACE mklproc)
target_include_directories(tfhe++ PUBLIC ${MKLROOT}/include ${PROJECT_SOURCE_DIR}/thirdparties/mkl)
else()
target_link_libraries(tfhe++ INTERFACE spqlios)
endif()
endif()
5 changes: 5 additions & 0 deletions test/polymul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ int main()
uniform_int_distribution<uint32_t> Bgdist(0, lvl1param::Bg);
uniform_int_distribution<uint32_t> Torus32dist(0, UINT32_MAX);

#ifdef USE_INTERLEAVED_FORMAT
std::cout << "USE_INTERLEAVED_FORMAT" << std::endl;
#endif


cout << "Start LVL1 test." << endl;
for (int test = 0; test < num_test; test++) {
Polynomial<lvl1param> a;
Expand Down
1 change: 1 addition & 0 deletions thirdparties/concrete-fft
Submodule concrete-fft added at 3c6274
104 changes: 44 additions & 60 deletions thirdparties/fftw/fft_processor_fftw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
FFT_Processor_FFTW::FFT_Processor_FFTW(const int32_t N)
: _2N(2 * N), N(N), Ns2(N / 2)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
plan_forward = fftw_plan_dft_1d(Ns2, in, out, FFTW_FORWARD, FFTW_MEASURE);
plan_backward = fftw_plan_dft_1d(Ns2, in, out, FFTW_BACKWARD, FFTW_MEASURE);
fftw_free(in);
fftw_free(out);
inbuf = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
outbuf = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
plan_forward = fftw_plan_dft_1d(Ns2, inbuf, outbuf, FFTW_FORWARD, FFTW_MEASURE);
plan_backward = fftw_plan_dft_1d(Ns2, inbuf, outbuf, FFTW_BACKWARD, FFTW_MEASURE);

for (int i = 0; i < Ns2; i++) {
double value = (double)i * M_PI / (double)N;
Expand All @@ -30,22 +28,16 @@ FFT_Processor_FFTW::FFT_Processor_FFTW(const int32_t N)

void FFT_Processor_FFTW::execute_reverse_int(double *res, const int32_t *a)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);

for (int i = 0; i < Ns2; i++) {
auto tmp = twist[i] * std::complex((double)a[i], (double)a[Ns2 + i]);
in[i][0] = tmp.real();
in[i][1] = tmp.imag();
inbuf[i][0] = tmp.real();
inbuf[i][1] = tmp.imag();
}
fftw_execute_dft(plan_forward, in, out);
fftw_execute_dft(plan_forward, inbuf, outbuf);
for (int i = 0; i < Ns2; i++) {
res[i] = out[i][0];
res[i + Ns2] = out[i][1];
res[i] = outbuf[i][0];
res[i + Ns2] = outbuf[i][1];
}

fftw_free(in);
fftw_free(out);
}

void FFT_Processor_FFTW::execute_reverse_torus32(double *res, const uint32_t *a)
Expand All @@ -55,90 +47,69 @@ void FFT_Processor_FFTW::execute_reverse_torus32(double *res, const uint32_t *a)

void FFT_Processor_FFTW::execute_reverse_torus64(double *res, const uint64_t *a)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);

for (int i = 0; i < Ns2; i++) {
auto tmp = twist[i] * std::complex((double)((int64_t)a[i]),
(double)((int64_t)a[Ns2 + i]));
in[i][0] = tmp.real();
in[i][1] = tmp.imag();
inbuf[i][0] = tmp.real();
inbuf[i][1] = tmp.imag();
}
fftw_execute_dft(plan_forward, in, out);
fftw_execute_dft(plan_forward, inbuf, outbuf);
for (int i = 0; i < Ns2; i++) {
res[i] = out[i][0];
res[i + Ns2] = out[i][1];
res[i] = outbuf[i][0];
res[i + Ns2] = outbuf[i][1];
}

fftw_free(in);
fftw_free(out);
}

void FFT_Processor_FFTW::execute_direct_torus32(uint32_t *res, const double *a)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);

for (int i = 0; i < Ns2; i++) {
in[i][0] = a[i] / Ns2;
in[i][1] = a[Ns2 + i] / Ns2;
inbuf[i][0] = a[i] / Ns2;
inbuf[i][1] = a[Ns2 + i] / Ns2;
}
fftw_execute_dft(plan_backward, in, out);
fftw_execute_dft(plan_backward, inbuf, outbuf);
for (int i = 0; i < Ns2; i++) {
auto res_tmp =
std::complex<double>(out[i][0], out[i][1]) * std::conj(twist[i]);
std::complex<double>(outbuf[i][0], outbuf[i][1]) * std::conj(twist[i]);
res[i] = CAST_DOUBLE_TO_UINT32(res_tmp.real());
res[i + Ns2] = CAST_DOUBLE_TO_UINT32(res_tmp.imag());
}

fftw_free(in);
fftw_free(out);
}

void FFT_Processor_FFTW::execute_direct_torus32_rescale(uint32_t *res,
const double *a,
const double Δ)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);

for (int i = 0; i < Ns2; i++) {
in[i][0] = a[i] / Ns2;
in[i][1] = a[Ns2 + i] / Ns2;
inbuf[i][0] = a[i] / Ns2;
inbuf[i][1] = a[Ns2 + i] / Ns2;
}
fftw_execute_dft(plan_backward, in, out);
fftw_execute_dft(plan_backward, inbuf, outbuf);
for (int i = 0; i < Ns2; i++) {
auto res_tmp =
std::complex<double>(out[i][0], out[i][1]) * std::conj(twist[i]);
std::complex<double>(outbuf[i][0], outbuf[i][1]) * std::conj(twist[i]);
res[i] = CAST_DOUBLE_TO_UINT32(res_tmp.real() / (Δ / 4));
res[i + Ns2] = CAST_DOUBLE_TO_UINT32(res_tmp.imag() / (Δ / 4));
}

fftw_free(in);
fftw_free(out);
}

void FFT_Processor_FFTW::execute_direct_torus64(uint64_t *res, const double *a)
{
auto in = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);
auto out = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * Ns2);

for (int i = 0; i < Ns2; i++) {
in[i][0] = a[i] / Ns2;
in[i][1] = a[Ns2 + i] / Ns2;
inbuf[i][0] = a[i] / Ns2;
inbuf[i][1] = a[Ns2 + i] / Ns2;
}
fftw_execute_dft(plan_backward, in, out);
static const uint64_t valmask0 = 0x000FFFFFFFFFFFFFul;
static const uint64_t valmask1 = 0x0010000000000000ul;
static const uint16_t expmask0 = 0x07FFu;
fftw_execute_dft(plan_backward, inbuf, outbuf);
double tmp[N];
for (int i = 0; i < Ns2; i++) {
auto res_tmp =
std::complex<double>(out[i][0], out[i][1]) * std::conj(twist[i]);
std::complex<double>(outbuf[i][0], outbuf[i][1]) * std::conj(twist[i]);
tmp[i] = res_tmp.real();
tmp[i + Ns2] = res_tmp.imag();
}
const uint64_t *const vals = (const uint64_t *)tmp;
constexpr uint64_t valmask0 = 0x000FFFFFFFFFFFFFul;
constexpr uint64_t valmask1 = 0x0010000000000000ul;
constexpr uint16_t expmask0 = 0x07FFu;
for (int i = 0; i < N; i++) {
uint64_t val = (vals[i] & valmask0) | valmask1; // mantissa on 53 bits
uint16_t expo = (vals[i] >> 52) & expmask0; // exponent 11 bits
Expand All @@ -148,9 +119,22 @@ void FFT_Processor_FFTW::execute_direct_torus64(uint64_t *res, const double *a)
uint64_t val2 = trans > 0 ? (val << trans) : (val >> -trans);
res[i] = (vals[i] >> 63) ? -val2 : val2;
}
}

fftw_free(in);
fftw_free(out);
void FFT_Processor_FFTW::execute_direct_torus64_rescale(uint64_t* res, const double* a, const double Δ) {
for (int i = 0; i < Ns2; i++) {
inbuf[i][0] = a[i] / Ns2;
inbuf[i][1] = a[Ns2 + i] / Ns2;
}
fftw_execute_dft(plan_backward, inbuf, outbuf);
double tmp[N];
for (int i = 0; i < Ns2; i++) {
auto res_tmp =
std::complex<double>(outbuf[i][0], outbuf[i][1]) * std::conj(twist[i]);
tmp[i] = res_tmp.real();
tmp[i + Ns2] = res_tmp.imag();
}
for (int i=0; i<N; i++) res[i] = uint64_t(std::round(tmp[i]/(Δ/4)));
}

FFT_Processor_FFTW::~FFT_Processor_FFTW()
Expand Down
4 changes: 4 additions & 0 deletions thirdparties/fftw/fft_processor_fftw.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class FFT_Processor_FFTW {
std::vector<std::complex<double>> twist;
fftw_plan plan_forward;
fftw_plan plan_backward;
fftw_complex * inbuf;
fftw_complex * outbuf;

public:
FFT_Processor_FFTW(const int32_t N);
Expand All @@ -36,6 +38,8 @@ class FFT_Processor_FFTW {

void execute_direct_torus64(uint64_t *res, const double *a);

void execute_direct_torus64_rescale(uint64_t *res, const double *a, const double Δ);

~FFT_Processor_FFTW();
};

Expand Down
7 changes: 7 additions & 0 deletions thirdparties/mkl/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
set(SRCS_MKLPROC fft_processor_mkl.cpp)

set(MKLPROC_HEADERS fft_processor_mkl.hpp)

add_library(mklproc STATIC ${SRCS_MKLPROC} ${MKLPROC_HEADERS})

target_include_directories(mklproc PUBLIC ${PROJECT_SOURCE_DIR}/include)
5 changes: 5 additions & 0 deletions thirdparties/mkl/fft_processor_mkl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "fft_processor_mkl.hpp"

// FFT_Processor_MKL is thread-safe
thread_local FFT_Processor_MKL<TFHEpp::lvl1param::n> fftplvl1;
thread_local FFT_Processor_MKL<TFHEpp::lvl2param::n> fftplvl2;
Loading

0 comments on commit 29c2b44

Please sign in to comment.