Skip to content

Commit

Permalink
refactor hbf
Browse files Browse the repository at this point in the history
  • Loading branch information
jordens committed Sep 15, 2023
1 parent 396de93 commit 4de5835
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 71 deletions.
162 changes: 103 additions & 59 deletions src/hbf.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// Filter input items into output items.
pub trait Filter {
/// Input/output item type.
// TODO: impl with generic item type
type Item;

/// Process a block of items.
Expand Down Expand Up @@ -37,24 +38,27 @@ pub trait Filter {

/// Symmetric FIR filter prototype.
///
/// DSP taps 2*M
/// # Generics
/// * `M`: number of taps, one-sided. The filter has effectively 2*M DSP taps
/// * `N`: state size: N = 2*M - 1 + {input/output}.len()
///
/// M: number of taps
/// N: state size: N = 2*M - 1 + {input/output}.len()
///
/// Decimation/interpolation filters
/// # Half band decimation/interpolation filters
///
/// These focus on half-band filters (rate change of 2) and cascades of HBF.
/// Half-band filters (rate change of 2) and cascades of HBFs are implemented in
/// [`HbfDec`] and [`HbfInt`] etc.
/// The half-band filter has unique properties that make it preferrable in many cases:
///
/// * only needs N multiplications (fused multiply accumulate) for 4*N taps
/// * stores less state compared with with a straight FIR
/// * only needs M multiplications (fused multiply accumulate) for 4*M taps
/// * HBF decimator stores less state than a generic FIR filter
/// * as a FIR filter has linear phase/flat group delay
/// * very small passband ripple and excellent stopband attenuation
/// * as a cascade of decimation/interpolation filters, the higher-rate filters
/// need successively fewer taps, allowing the filtering to be dominated by
/// only the highest rate filter with the fewest taps
/// * high dynamic range (compared with a biquad IIR)
/// * In a cascade of HBF the overall latency, group delay, and impulse response
/// length are dominated by the lowest-rate filter which, due to its manageable transition
/// band width (compared to single-stage filters) can be smaller, shorter, and faster.
/// * high dynamic range and inherent stability compared with an IIR filter
/// * can be combined with a CIC filter for non-power-of-two or even higher rate changes
///
/// The implementations here are all `no_std` and `no-alloc`.
Expand All @@ -71,12 +75,21 @@ pub struct SymFir<'a, const M: usize, const N: usize> {
}

impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> {
/// taps: one-sided, expluding center tap, oldest to one-before-center
/// Create a new `SymFir`.
///
/// # Args
/// * `taps`: one-sided FIR coefficients, expluding center tap, oldest to one-before-center
pub fn new(taps: &'a [f32; M]) -> Self {
debug_assert!(N >= M * 2);
Self { x: [0.0; N], taps }
}

/// Obtain a mutable reference to the input items buffer space.
#[inline]
pub fn buf_mut(&mut self) -> &mut [f32] {
&mut self.x[2 * M - 1..]
}

/// Perform the FIR convolution and yield results iteratively.
#[inline]
pub fn get(&self) -> impl Iterator<Item = f32> + '_ {
Expand All @@ -89,6 +102,15 @@ impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> {
.sum()
})
}

/// Move items as new filter state.
///
/// # Args
/// * `offset`: Keep the `2*M-1` items at `offset` as the new filter state.
#[inline]
pub fn keep_state(&mut self, offset: usize) {
self.x.copy_within(offset..offset + 2 * M - 1, 0);
}
}

// TODO: pub struct SymFirInt<R>, SymFirDec<R>
Expand All @@ -106,8 +128,11 @@ pub struct HbfDec<'a, const M: usize, const N: usize> {
}

impl<'a, const M: usize, const N: usize> HbfDec<'a, M, N> {
/// Non-zero (odd) taps from oldest to one-before-center.
/// Normalized such that center tap is 1.
/// Create a new `HbfDec`.
///
/// # Args
/// * `taps`: The FIR filter coefficients. Only the non-zero (odd) taps
/// from oldest to one-before-center. Normalized such that center tap is 1.
pub fn new(taps: &'a [f32; M]) -> Self {
Self {
even: [0.0; N],
Expand Down Expand Up @@ -141,7 +166,7 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> {
for (xi, (even, odd)) in x.chunks_exact(2).zip(
self.even[M - 1..][..k]
.iter_mut()
.zip(self.odd.x[2 * M - 1..][..k].iter_mut()),
.zip(self.odd.buf_mut()[..k].iter_mut()),
) {
*even = xi[0];
*odd = xi[1];
Expand All @@ -155,7 +180,7 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> {
}
// keep state
self.even.copy_within(k..k + M - 1, 0);
self.odd.x.copy_within(k..k + 2 * M - 1, 0);
self.odd.keep_state(k);
&mut y[..k]
}
}
Expand All @@ -179,6 +204,11 @@ impl<'a, const M: usize, const N: usize> HbfInt<'a, M, N> {
fir: SymFir::new(taps),
}
}

/// Obtain a mutable reference to the input items buffer space
pub fn buf_mut(&mut self) -> &mut [f32] {
self.fir.buf_mut()
}
}

impl<'a, const M: usize, const N: usize> Filter for HbfInt<'a, M, N> {
Expand All @@ -203,17 +233,20 @@ impl<'a, const M: usize, const N: usize> Filter for HbfInt<'a, M, N> {
let k = y.len() / 2;
let x = x.unwrap_or(&y[..k]);
// load input
self.fir.x[2 * M - 1..][..k].copy_from_slice(x);
self.fir.buf_mut()[..k].copy_from_slice(x);
// compute output
for (yi, (even, &odd)) in y
.chunks_exact_mut(2)
.zip(self.fir.get().zip(self.fir.x[M..][..k].iter()))
{
yi[0] = even;
yi[1] = odd;
// Choose the even item to be the interpolated one.
// The alternative would have the same response length
// but larger latency.
yi[0] = even; // interpolated
yi[1] = odd; // center tap: identity
}
// keep state
self.fir.x.copy_within(k..k + 2 * M - 1, 0);
self.fir.keep_state(k);
y
}
}
Expand Down Expand Up @@ -277,7 +310,7 @@ pub const HBF_CASCADE_BLOCK: usize = 1 << 6;
/// Supports rate changes of 1, 2, 4, 8, and 16.
#[derive(Copy, Clone, Debug)]
pub struct HbfDecCascade {
n: usize,
depth: usize,
stages: (
HbfDec<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>,
HbfDec<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>,
Expand All @@ -289,7 +322,7 @@ pub struct HbfDecCascade {
impl Default for HbfDecCascade {
fn default() -> Self {
Self {
n: 0,
depth: 0,
stages: (
HbfDec::new(&HBF_TAPS.0),
HbfDec::new(&HBF_TAPS.1),
Expand All @@ -302,14 +335,14 @@ impl Default for HbfDecCascade {

impl HbfDecCascade {
#[inline]
pub fn set_n(&mut self, n: usize) {
pub fn set_depth(&mut self, n: usize) {
assert!(n <= 4);
self.n = n;
self.depth = n;
}

#[inline]
pub fn n(&self) -> usize {
self.n
pub fn depth(&self) -> usize {
self.depth
}
}

Expand All @@ -319,8 +352,8 @@ impl Filter for HbfDecCascade {
#[inline]
fn block_size(&self) -> (usize, usize) {
(
1 << self.n,
match self.n {
1 << self.depth,
match self.depth {
0 => usize::MAX,
1 => self.stages.0.block_size().1,
2 => self.stages.1.block_size().1,
Expand All @@ -333,16 +366,16 @@ impl Filter for HbfDecCascade {
#[inline]
fn response_length(&self) -> usize {
let mut n = 0;
if self.n > 3 {
if self.depth > 3 {
n = n / 2 + self.stages.3.response_length();
}
if self.n > 2 {
if self.depth > 2 {
n = n / 2 + self.stages.2.response_length();
}
if self.n > 1 {
if self.depth > 1 {
n = n / 2 + self.stages.1.response_length();
}
if self.n > 0 {
if self.depth > 0 {
n = n / 2 + self.stages.0.response_length();
}
n
Expand All @@ -358,31 +391,35 @@ impl Filter for HbfDecCascade {
}
let n = y.len();

if self.n > 3 {
if self.depth > 3 {
y = self.stages.3.process_block(None, y);
}
if self.n > 2 {
if self.depth > 2 {
y = self.stages.2.process_block(None, y);
}
if self.n > 1 {
if self.depth > 1 {
y = self.stages.1.process_block(None, y);
}
if self.n > 0 {
if self.depth > 0 {
y = self.stages.0.process_block(None, y);
}
debug_assert_eq!(y.len(), n >> self.n);
debug_assert_eq!(y.len(), n >> self.depth);
y
}
}

/// Half-band interpolation filter cascade with optimal taps.
///
/// This is a no_alloc version without trait objects.
/// The price to pay is fixed and maximal memory usage independent
/// of block size and cascade length.
///
/// See [HBF_TAPS].
/// Only in-place processing is implemented.
/// Supports rate changes of 1, 2, 4, 8, and 16.
#[derive(Copy, Clone, Debug)]
pub struct HbfIntCascade {
n: usize,
depth: usize,
pub stages: (
HbfInt<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>,
HbfInt<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>,
Expand All @@ -394,7 +431,7 @@ pub struct HbfIntCascade {
impl Default for HbfIntCascade {
fn default() -> Self {
Self {
n: 4,
depth: 4,
stages: (
HbfInt::new(&HBF_TAPS.0),
HbfInt::new(&HBF_TAPS.1),
Expand All @@ -406,13 +443,13 @@ impl Default for HbfIntCascade {
}

impl HbfIntCascade {
pub fn set_n(&mut self, n: usize) {
pub fn set_depth(&mut self, n: usize) {
assert!(n <= 4);
self.n = n;
self.depth = n;
}

pub fn n(&self) -> usize {
self.n
pub fn depth(&self) -> usize {
self.depth
}
}

Expand All @@ -422,8 +459,8 @@ impl Filter for HbfIntCascade {
#[inline]
fn block_size(&self) -> (usize, usize) {
(
1 << self.n,
match self.n {
1 << self.depth,
match self.depth {
0 => usize::MAX,
1 => self.stages.0.block_size().1,
2 => self.stages.1.block_size().1,
Expand All @@ -436,16 +473,16 @@ impl Filter for HbfIntCascade {
#[inline]
fn response_length(&self) -> usize {
let mut n = 0;
if self.n > 0 {
if self.depth > 0 {
n = 2 * n + self.stages.0.response_length();
}
if self.n > 1 {
if self.depth > 1 {
n = 2 * n + self.stages.1.response_length();
}
if self.n > 2 {
if self.depth > 2 {
n = 2 * n + self.stages.2.response_length();
}
if self.n > 3 {
if self.depth > 3 {
n = 2 * n + self.stages.3.response_length();
}
n
Expand All @@ -459,18 +496,19 @@ impl Filter for HbfIntCascade {
if x.is_some() {
unimplemented!(); // TODO: one intermediate buffer and `y`
}
// TODO: use buf_mut() and write directly into next filters' input buffer

let mut n = y.len() >> self.n;
if self.n > 0 {
let mut n = y.len() >> self.depth;
if self.depth > 0 {
n = self.stages.0.process_block(None, &mut y[..2 * n]).len();
}
if self.n > 1 {
if self.depth > 1 {
n = self.stages.1.process_block(None, &mut y[..2 * n]).len();
}
if self.n > 2 {
if self.depth > 2 {
n = self.stages.2.process_block(None, &mut y[..2 * n]).len();
}
if self.n > 3 {
if self.depth > 3 {
n = self.stages.3.process_block(None, &mut y[..2 * n]).len();
}
debug_assert_eq!(n, y.len());
Expand Down Expand Up @@ -503,18 +541,24 @@ mod test {
#[test]
fn decim() {
let mut h = HbfDecCascade::default();
h.set_n(4);
assert_eq!(h.block_size(), (1 << h.n(), HBF_CASCADE_BLOCK << h.n()));
let mut x: Vec<_> = (0..2 << h.n()).map(|i| i as f32).collect();
h.set_depth(4);
assert_eq!(
h.block_size(),
(1 << h.depth(), HBF_CASCADE_BLOCK << h.depth())
);
let mut x: Vec<_> = (0..2 << h.depth()).map(|i| i as f32).collect();
let x = h.process_block(None, &mut x);
println!("{:?}", x);
}

#[test]
fn interp() {
let mut h = HbfIntCascade::default();
h.set_n(4);
assert_eq!(h.block_size(), (1 << h.n(), HBF_CASCADE_BLOCK << h.n()));
h.set_depth(4);
assert_eq!(
h.block_size(),
(1 << h.depth(), HBF_CASCADE_BLOCK << h.depth())
);
let k = h.block_size().0;
let r = h.response_length();
let mut x = vec![0.0; (r + 1 + k - 1) / k * k];
Expand All @@ -524,7 +568,7 @@ mod test {
assert!(x[r] != 0.0);
assert_eq!(x[r + 1..], vec![0.0; x.len() - r - 1]);

let g = (1 << h.n()) as f32;
let g = (1 << h.depth()) as f32;
let mut y = Vec::from_iter(x.iter().map(|&x| Complex { re: x / g, im: 0.0 }));
// pad
y.resize(5 << 10, Complex::default());
Expand Down Expand Up @@ -579,7 +623,7 @@ mod test {
fn insn_casc() {
let mut x = [9.0; 1 << 10];
let mut h = HbfDecCascade::default();
h.set_n(4);
h.set_depth(4);
for _ in 0..1 << 20 {
h.process_block(None, &mut x);
}
Expand Down
Loading

0 comments on commit 4de5835

Please sign in to comment.