Skip to content

Commit

Permalink
py unittest and benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
kwsp committed Jan 5, 2025
1 parent 0cc2c76 commit 26fa393
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 114 deletions.
107 changes: 107 additions & 0 deletions py/bench/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from google_benchmark import State
import google_benchmark as benchmark
import numpy as np

from pyfftconv import oaconvolve, oaconvolve_, convolve, convolve_, hilbert, hilbert_

ARGS_FIR = [[2304, 4352], [165]]


def BM_conv(fn, state: State, x, k, mode: str):
fn(x, k, mode)
while state:
fn(x, k, mode)
state.items_processed = state.iterations * x.size
state.bytes_processed = state.items_processed * x.itemsize


def BM_conv_out(fn, state: State, x, k, out, mode: str):
fn(x, k, out, mode)
while state:
fn(x, k, out, mode)
state.items_processed = state.iterations * x.size
state.bytes_processed = state.items_processed * x.itemsize


def BM_hilbert(fn, state: State, x):
fn(x)
while state:
fn(x)
state.items_processed = state.iterations * x.size
state.bytes_processed = state.items_processed * x.itemsize


def BM_hilbert_out(fn, state: State, x, out):
fn(x, out)
while state:
fn(x, out)
state.items_processed = state.iterations * x.size
state.bytes_processed = state.items_processed * x.itemsize


@benchmark.register
@benchmark.option.args_product(ARGS_FIR)
def BM_oaconvolve_same_double(state: State):
x = np.random.random(state.range(0)).astype(np.float64)
k = np.random.random(state.range(1)).astype(np.float64)
BM_conv(oaconvolve, state, x, k, "same")


@benchmark.register
@benchmark.option.args_product(ARGS_FIR)
def BM_oaconvolve_same_float(state: State):
x = np.random.random(state.range(0)).astype(np.float32)
k = np.random.random(state.range(1)).astype(np.float32)
BM_conv(oaconvolve, state, x, k, "same")


@benchmark.register
@benchmark.option.args_product(ARGS_FIR)
def BM_oaconvolve_same_out_double(state: State):
x = np.random.random(state.range(0)).astype(np.float64)
k = np.random.random(state.range(1)).astype(np.float64)
out = np.zeros_like(x)
BM_conv_out(oaconvolve_, state, x, k, out, "same")


@benchmark.register
@benchmark.option.args_product(ARGS_FIR)
def BM_oaconvolve_same_out_float(state: State):
x = np.random.random(state.range(0)).astype(np.float32)
k = np.random.random(state.range(1)).astype(np.float32)
out = np.zeros_like(x)
BM_conv_out(oaconvolve_, state, x, k, out, "same")


@benchmark.register
@benchmark.option.dense_range(2048, 6144, 1024)
def BM_hilbert_float(state: State):
x = np.random.random(state.range(0)).astype(np.float32)
BM_hilbert(hilbert, state, x)


@benchmark.register
@benchmark.option.dense_range(2048, 6144, 1024)
def BM_hilbert_double(state: State):
x = np.random.random(state.range(0)).astype(np.float64)
BM_hilbert(hilbert, state, x)


@benchmark.register
@benchmark.option.dense_range(2048, 6144, 1024)
def BM_hilbert_out_float(state: State):
x = np.random.random(state.range(0)).astype(np.float32)
out = np.zeros_like(x)
BM_hilbert_out(hilbert_, state, x, out)


@benchmark.register
@benchmark.option.dense_range(2048, 6144, 1024)
def BM_hilbert_out_double(state: State):
x = np.random.random(state.range(0)).astype(np.float64)
out = np.zeros_like(x)
BM_hilbert_out(hilbert_, state, x, out)


if __name__ == "__main__":
benchmark.main()
68 changes: 64 additions & 4 deletions py/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,53 @@ static ConvMode parseConvMode(const std::string &mode) {
throw std::runtime_error("Unsupported convolution mode: " + mode);
}

template <typename T>
static void convolve_(py::array_t<T> a, py::array_t<T> k, py::array_t<T> out,
const std::string &modeStr) {

py::buffer_info bufIn = a.request();
py::buffer_info bufK = k.request();
py::buffer_info bufOut = out.request();

// error handling
if (bufIn.ndim != 1 || bufK.ndim != 1 || bufOut.ndim != 1) {
throw std::runtime_error("Number of dimensions must be one");
}

if (bufIn.size < bufK.size) {
throw std::runtime_error("Kernel size must be smaller than input size");
}

// Execute conv
const auto mode = parseConvMode(modeStr);
if (mode == fftconv::ConvMode::Same) {
fftconv::convolve_fftw<T, fftconv::ConvMode::Same>(as_span(a), as_span(k),
as_mutable_span(out));
} else {
fftconv::convolve_fftw<T, fftconv::ConvMode::Full>(as_span(a), as_span(k),
as_mutable_span(out));
}
}

// Same API as np.convolve
template <typename T>
static py::array_t<T> convolve(py::array_t<T> a, py::array_t<T> k,
const std::string &modeStr) {
// Same
py::ssize_t outSize{};

const auto mode = parseConvMode(modeStr);
if (mode == fftconv::ConvMode::Same) {
outSize = a.size();
} else { // Full
outSize = a.size() + k.size() - 1;
}

py::array_t<T> out(outSize);
convolve_(a, k, out, modeStr);
return out;
}

template <typename T>
static void oaconvolve_(py::array_t<T> a, py::array_t<T> k, py::array_t<T> out,
const std::string &modeStr) {
Expand Down Expand Up @@ -73,10 +120,6 @@ static py::array_t<T> oaconvolve(py::array_t<T> a, py::array_t<T> k,
return out;
}

const char *const oaconvolve_doc = R"delimiter(
Performs overlap-add convolution using FFTW. API compatible with np.convolve
)delimiter";

template <typename T>
static void hilbert_(py::array_t<T> a, py::array_t<T> out) {
fftconv::hilbert<T>(as_span(a), as_mutable_span(out));
Expand All @@ -88,6 +131,14 @@ template <typename T> static py::array_t<T> hilbert(py::array_t<T> a) {
return out;
}

const char *const convolve_doc = R"delimiter(
Performs convolution using FFTW. API compatible with np.convolve
)delimiter";

const char *const oaconvolve_doc = R"delimiter(
Performs overlap-add convolution using FFTW. API compatible with np.convolve
)delimiter";

const char *const hilbert_doc = R"delimiter(
Performs envelope detection using the Hilbert transform.
Equivalent to `np.abs(signal.hilbert(a))`
Expand All @@ -97,6 +148,15 @@ PYBIND11_MODULE(_pyfftconv, m) {
m.doc() = "Python wrapper for fftconv";
m.attr("__version__") = FFTCONV_VERSION;

m.def("convolve", convolve<double>, py::arg("a"), py::arg("k"),
py::arg("mode") = "full", convolve_doc);
m.def("convolve", convolve<float>, py::arg("a"), py::arg("k"),
py::arg("mode") = "full", convolve_doc);
m.def("convolve_", convolve_<double>, py::arg("a"), py::arg("k"),
py::arg("out"), py::arg("mode") = "full", convolve_doc);
m.def("convolve_", convolve_<float>, py::arg("a"), py::arg("k"),
py::arg("out"), py::arg("mode") = "full", convolve_doc);

m.def("oaconvolve", oaconvolve<double>, py::arg("a"), py::arg("k"),
py::arg("mode") = "full", oaconvolve_doc);
m.def("oaconvolve", oaconvolve<float>, py::arg("a"), py::arg("k"),
Expand Down
11 changes: 10 additions & 1 deletion py/pyfftconv/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
from ._pyfftconv import __doc__, __version__, oaconvolve, oaconvolve_, hilbert, hilbert_
from ._pyfftconv import (
__doc__,
__version__,
convolve,
convolve_,
oaconvolve,
oaconvolve_,
hilbert,
hilbert_,
)
13 changes: 13 additions & 0 deletions py/pyfftconv/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Python wrapper for fftconv
"""
from __future__ import annotations
from pyfftconv._pyfftconv import convolve
from pyfftconv._pyfftconv import convolve_
from pyfftconv._pyfftconv import hilbert
from pyfftconv._pyfftconv import hilbert_
from pyfftconv._pyfftconv import oaconvolve
from pyfftconv._pyfftconv import oaconvolve_
from . import _pyfftconv
__all__ = ['convolve', 'convolve_', 'hilbert', 'hilbert_', 'oaconvolve', 'oaconvolve_']
__version__: str = '0.5.1'
72 changes: 72 additions & 0 deletions py/pyfftconv/_pyfftconv.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Python wrapper for fftconv
"""
from __future__ import annotations
import numpy
import typing
__all__ = ['convolve', 'convolve_', 'hilbert', 'hilbert_', 'oaconvolve', 'oaconvolve_']
@typing.overload
def convolve(a: numpy.ndarray[numpy.float64], k: numpy.ndarray[numpy.float64], mode: str = 'full') -> numpy.ndarray[numpy.float64]:
"""
Performs convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def convolve(a: numpy.ndarray[numpy.float32], k: numpy.ndarray[numpy.float32], mode: str = 'full') -> numpy.ndarray[numpy.float32]:
"""
Performs convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def convolve_(a: numpy.ndarray[numpy.float64], k: numpy.ndarray[numpy.float64], out: numpy.ndarray[numpy.float64], mode: str = 'full') -> None:
"""
Performs convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def convolve_(a: numpy.ndarray[numpy.float32], k: numpy.ndarray[numpy.float32], out: numpy.ndarray[numpy.float32], mode: str = 'full') -> None:
"""
Performs convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def hilbert(a: numpy.ndarray[numpy.float64]) -> numpy.ndarray[numpy.float64]:
"""
Performs envelope detection using the Hilbert transform.
Equivalent to `np.abs(signal.hilbert(a))`
"""
@typing.overload
def hilbert(a: numpy.ndarray[numpy.float32]) -> numpy.ndarray[numpy.float32]:
"""
Performs envelope detection using the Hilbert transform.
Equivalent to `np.abs(signal.hilbert(a))`
"""
@typing.overload
def hilbert_(a: numpy.ndarray[numpy.float64], out: numpy.ndarray[numpy.float64]) -> None:
"""
Performs envelope detection using the Hilbert transform.
Equivalent to `np.abs(signal.hilbert(a))`
"""
@typing.overload
def hilbert_(a: numpy.ndarray[numpy.float32], out: numpy.ndarray[numpy.float32]) -> None:
"""
Performs envelope detection using the Hilbert transform.
Equivalent to `np.abs(signal.hilbert(a))`
"""
@typing.overload
def oaconvolve(a: numpy.ndarray[numpy.float64], k: numpy.ndarray[numpy.float64], mode: str = 'full') -> numpy.ndarray[numpy.float64]:
"""
Performs overlap-add convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def oaconvolve(a: numpy.ndarray[numpy.float32], k: numpy.ndarray[numpy.float32], mode: str = 'full') -> numpy.ndarray[numpy.float32]:
"""
Performs overlap-add convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def oaconvolve_(a: numpy.ndarray[numpy.float64], k: numpy.ndarray[numpy.float64], out: numpy.ndarray[numpy.float64], mode: str = 'full') -> None:
"""
Performs overlap-add convolution using FFTW. API compatible with np.convolve
"""
@typing.overload
def oaconvolve_(a: numpy.ndarray[numpy.float32], k: numpy.ndarray[numpy.float32], out: numpy.ndarray[numpy.float32], mode: str = 'full') -> None:
"""
Performs overlap-add convolution using FFTW. API compatible with np.convolve
"""
__version__: str = '0.5.1'
59 changes: 0 additions & 59 deletions py/test/test_conv.py

This file was deleted.

Loading

0 comments on commit 26fa393

Please sign in to comment.