Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement saving and loading of pyFFTW wisdom for substantial speedup #11

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 129 additions & 100 deletions nsgt/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,122 +11,151 @@
AudioMiner project, supported by Vienna Science and Technology Fund (WWTF)
"""

import atexit
import numpy as np
import os.path
import pickle
from threading import Timer
from warnings import warn

realized = False

if not realized:
# try to use FFT3 if available, else use numpy.fftpack
# Try engines in order of:
# PyFFTW3
# pyFFTW
# numpy.fftpack
try:
import fftw3, fftw3f
except ImportError:
fftw3 = None
fftw3f = None
try:
import fftw3
import pyfftw
except ImportError:
fftw3 = None

try:
import fftw3f
except ImportError:
fftw3f = None

if fftw3 is not None:
# fftw3 methods
class fftpool:
def __init__(self, measure, dtype=float):
self.measure = measure
self.dtype = np.dtype(dtype)
dtsz = self.dtype.itemsize
if dtsz == 4:
self.tpfloat = np.float32
self.tpcplx = np.complex64
self.fftw = fftw3f
elif dtsz == 8:
self.tpfloat = np.float64
self.tpcplx = np.complex128
self.fftw = fftw3
else:
raise TypeError("nsgt.fftpool: dtype '%s' not supported"%repr(self.dtype))
self.pool = {}

def __call__(self, x, outn=None, ref=False):
lx = len(x)
try:
transform = self.pool[lx]
except KeyError:
transform = self.init(lx, measure=self.measure, outn=outn)
self.pool[lx] = transform
plan,pre,post = transform
if pre is not None:
x = pre(x)
plan.inarray[:] = x
plan()
if not ref:
tx = plan.outarray.copy()
else:
tx = plan.outarray
if post is not None:
tx = post(tx)
return tx

class fftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='forward', flags=('measure' if measure else 'estimate',))
return (plan,None,None)

class rfftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpfloat)
outp = self.fftw.create_aligned_array(n//2+1, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='forward', realtypes='halfcomplex r2c',flags=('measure' if measure else 'estimate',))
return (plan,None,None)

class ifftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='backward', flags=('measure' if measure else 'estimate',))
return (plan,None,lambda x: x/len(x))

class irfftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(outn if outn is not None else (n-1)*2, dtype=self.tpfloat)
plan = self.fftw.Plan(inp, outp, direction='backward', realtypes='halfcomplex c2r', flags=('measure' if measure else 'estimate',))
return (plan,lambda x: x[:n],lambda x: x/len(x))

realized = True


if not realized:
pyfftw = None


if fftw3 is not None and fftw3f is not None:
ENGINE = "PYFFTW3"
# Use fftw3 methods
class fftpool:
def __init__(self, measure, dtype=float):
self.measure = measure
self.dtype = np.dtype(dtype)
dtsz = self.dtype.itemsize
if dtsz == 4:
self.tpfloat = np.float32
self.tpcplx = np.complex64
self.fftw = fftw3f
elif dtsz == 8:
self.tpfloat = np.float64
self.tpcplx = np.complex128
self.fftw = fftw3
else:
raise TypeError("nsgt.fftpool: dtype '%s' not supported"%repr(self.dtype))
self.pool = {}

def __call__(self, x, outn=None, ref=False):
lx = len(x)
try:
transform = self.pool[lx]
except KeyError:
transform = self.init(lx, measure=self.measure, outn=outn)
self.pool[lx] = transform
plan,pre,post = transform
if pre is not None:
x = pre(x)
plan.inarray[:] = x
plan()
if not ref:
tx = plan.outarray.copy()
else:
tx = plan.outarray
if post is not None:
tx = post(tx)
return tx

class fftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='forward', flags=('measure' if measure else 'estimate',))
return (plan,None,None)

class rfftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpfloat)
outp = self.fftw.create_aligned_array(n//2+1, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='forward', realtypes='halfcomplex r2c',flags=('measure' if measure else 'estimate',))
return (plan,None,None)

class ifftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
plan = self.fftw.Plan(inp, outp, direction='backward', flags=('measure' if measure else 'estimate',))
return (plan,None,lambda x: x/len(x))

class irfftp(fftpool):
def __init__(self, measure=False, dtype=float):
fftpool.__init__(self, measure, dtype=dtype)
def init(self, n, measure, outn):
inp = self.fftw.create_aligned_array(n, dtype=self.tpcplx)
outp = self.fftw.create_aligned_array(outn if outn is not None else (n-1)*2, dtype=self.tpfloat)
plan = self.fftw.Plan(inp, outp, direction='backward', realtypes='halfcomplex c2r', flags=('measure' if measure else 'estimate',))
return (plan,lambda x: x[:n],lambda x: x/len(x))
elif pyfftw is not None:
ENGINE = "PYFFTW"
# Monkey patch in pyFFTW Numpy interface
np.fft = pyfftw.interfaces.numpy_fft
original_fft = np.fft
# Enable cache to keep wisdom, etc.
pyfftw.interfaces.cache.enable()

# Load stored wisdom
PYFFTW_WISDOM_FILENAME = os.path.join(os.path.expanduser("~"), ".nsgt_pyfftw_wisdom.p")
if os.path.isfile(PYFFTW_WISDOM_FILENAME):
with open(PYFFTW_WISDOM_FILENAME, 'rb') as f:
pyfftw.import_wisdom(pickle.load(f))
print("Loaded pyFFTW wisdom from %s" % PYFFTW_WISDOM_FILENAME)

def save_wisdom():
print("Saving pyFFTW wisdom to %s" % PYFFTW_WISDOM_FILENAME)
with open(PYFFTW_WISDOM_FILENAME, 'wb') as f:
pickle.dump(pyfftw.export_wisdom(), f)

# Save wisdom on exit
atexit.register(save_wisdom)
else:
# fall back to numpy methods
warn("nsgt.fft falling back to numpy.fft")

ENGINE = "NUMPY"

if ENGINE in ["PYFFTW", "NUMPY"]:
def get_kwargs(measure):
return ({'planner_effort': 'FFTW_MEASURE' if measure else 'FFTW_ESTIMATE'}
if ENGINE=="PYFFTW" else {})
class fftp:
def __init__(self, measure=False, dtype=float):
pass
self.kwargs = get_kwargs(measure)
def __call__(self,x, outn=None, ref=False):
return np.fft.fft(x)
return np.fft.fft(x, **self.kwargs)
class ifftp:
def __init__(self, measure=False, dtype=float):
pass
self.kwargs = get_kwargs(measure)
def __call__(self,x, outn=None, n=None, ref=False):
return np.fft.ifft(x,n=n)
return np.fft.ifft(x,n=n,**self.kwargs)
class rfftp:
def __init__(self, measure=False, dtype=float):
pass
self.kwargs = get_kwargs(measure)
def __call__(self,x, outn=None, ref=False):
return np.fft.rfft(x)
return np.fft.rfft(x,**self.kwargs)
class irfftp:
def __init__(self, measure=False, dtype=float):
pass
self.kwargs = get_kwargs(measure)
def __call__(self,x,outn=None,ref=False):
return np.fft.irfft(x,n=outn)
return np.fft.irfft(x,n=outn,**self.kwargs)
6 changes: 5 additions & 1 deletion tests/cq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from nsgt import NSGT, OctScale
import unittest

# Make test deterministic
np.random.seed(666)


class TestNSGT(unittest.TestCase):

def test_oct(self):
Expand All @@ -14,7 +18,7 @@ def test_oct(self):
nsgt = NSGT(scale, fs=44100, Ls=len(sig))
c = nsgt.forward(sig)
s_r = nsgt.backward(c)
self.assertTrue(np.allclose(sig, s_r))
self.assertTrue(np.allclose(sig, s_r, atol=1e-07))

def load_tests(*_):
test_cases = unittest.TestSuite()
Expand Down
4 changes: 4 additions & 0 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from nsgt.fft import rfftp, irfftp, fftp, ifftp
import unittest

# Make test deterministic
np.random.seed(666)


class TestFFT(unittest.TestCase):
def __init__(self, methodName, n=10000):
super(TestFFT, self).__init__(methodName)
Expand Down
6 changes: 5 additions & 1 deletion tests/nsgt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from nsgt import CQ_NSGT
import unittest

# Make test deterministic
np.random.seed(666)


class Test_CQ_NSGT(unittest.TestCase):

def test_transform(self, length=100000, fmin=50, fmax=22050, bins=12, fs=44100):
Expand All @@ -13,7 +17,7 @@ def test_transform(self, length=100000, fmin=50, fmax=22050, bins=12, fs=44100):
# inverse transform
s_r = nsgt.backward(c)

self.assertTrue(np.allclose(s, s_r))
self.assertTrue(np.allclose(s, s_r, atol=1e-07))

def load_tests(*_):
test_cases = unittest.TestSuite()
Expand Down