Skip to content

Commit

Permalink
dev(narugo): add unittest for sp metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Aug 27, 2024
1 parent 044f77a commit cabdc04
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 80 deletions.
36 changes: 34 additions & 2 deletions soundutils/data/sound.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os
from datetime import datetime, timedelta
from typing import Tuple, Optional, Union

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import soundfile
from hbutils.string import plural_word
from matplotlib.ticker import FuncFormatter
from scipy import signal

SoundTyping = Union[str, os.PathLike, 'Sound']
Expand Down Expand Up @@ -98,14 +102,16 @@ def from_numpy(cls, data: np.ndarray, sample_rate: int) -> 'Sound':
return cls(data, sample_rate)

def to_numpy(self) -> Tuple[np.ndarray, int]:
# dump data to numpy format
# it has been 100% aligned with torchaudio's loading result
return self._to_numpy().T, self._sample_rate

@classmethod
def open(cls, sound_file: str) -> 'Sound':
def open(cls, sound_file: Union[str, os.PathLike]) -> 'Sound':
data, sample_rate = soundfile.read(sound_file)
return cls(data, sample_rate)

def save(self, sound_file: str):
def save(self, sound_file: Union[str, os.PathLike]):
soundfile.write(sound_file, self._data, self._sample_rate)

@classmethod
Expand All @@ -116,3 +122,29 @@ def load(cls, sound: SoundTyping) -> 'Sound':
return cls.open(sound)
else:
raise TypeError(f'Unknown sound type - {sound!r}.')

def plot(self, ax=None, title: Optional[str] = None):
times = np.arange(self.samples) / float(self._sample_rate)
base_time = datetime(1970, 1, 1)
times = [base_time + timedelta(seconds=t) for t in times]
times = mdates.date2num(times)

ax = ax or plt.gca()
data = self._to_numpy()
for cid in range(self.channels):
ax.plot(times, data[:, cid], label=f'Channel #{cid}', alpha=0.5)

def _fmt_time(x, pos):
dt, _ = mdates.num2date(x), pos
return dt.strftime('%H:%M:%S') + f'.{int(dt.microsecond / 1000):03d}'

ax.xaxis.set_major_formatter(FuncFormatter(_fmt_time))
locator = mdates.AutoDateLocator(minticks=5, maxticks=10)
ax.xaxis.set_major_locator(locator)

ax.set_xlabel('Time [hh:mm:ss.mmm]')
ax.set_ylabel('Amplitude')
ax.set_title(f'{title or "Audio Signal"}\n'
f'Channels: {self.channels}, Sample Rate: {self._sample_rate}\n'
f'Time: {self.time:.3f}s ({plural_word(self.samples, "frame")})\n')
ax.legend()
7 changes: 4 additions & 3 deletions soundutils/similarity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import SoundAlignError, SoundLengthNotMatch, SoundResampleRateNotMatch, SoundChannelsNotMatch
from .correlation import sound_correlation
from .cosine import sound_cosine_similarity
from .euclidean import sound_euclidean
from .correlation import sound_pearson_similarity
from .dtw import sound_fastdtw
from .mse import sound_mse, sound_rmse
from .spectral import sound_spectral_centroid_distance
51 changes: 40 additions & 11 deletions soundutils/similarity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class SoundLengthNotMatch(SoundAlignError):
def _align_sounds(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['pad', 'resample', 'none'] = 'none',
channels_align: Literal['none', 'noncheck'] = 'none',
) -> Tuple[np.ndarray, np.ndarray]:
time_align: Literal['none', 'noncheck', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'none',
channels_align: Literal['none'] = 'none',
) -> Tuple[Tuple[np.ndarray, int], Tuple[np.ndarray, int]]:
sound1, sound2 = Sound.load(sound1), Sound.load(sound2)
if channels_align == 'none':
if sound1.channels != sound2.channels:
Expand Down Expand Up @@ -66,11 +66,40 @@ def _align_sounds(
f'{data1.shape[-1] / sr1:.3f}s ({plural_word(data1.shape[-1], "frame")}) vs '
f'{data2.shape[-1] / sr2:.3f}s ({plural_word(data2.shape[-1], "frame")}).')
else:
# shape of data1 and data2: (channels, frames)
# TODO: support 3 modes of time_align:
# * 'pad', pad the sound data with fewer frames with all 0 constants
# * 'resample_max', resample the shorter sound data to the longer one's frames
# * 'resample_min', resample the longer sound data to the shorter one's frames
raise NotImplementedError

return data1, data2
if time_align == 'pad':
# Pad the shorter sound with zeros
max_frames = max(data1.shape[-1], data2.shape[-1])
if data1.shape[-1] < max_frames:
pad_width = ((0, 0), (0, max_frames - data1.shape[-1]))
data1 = np.pad(data1, pad_width, mode='constant')
elif data2.shape[-1] < max_frames:
pad_width = ((0, 0), (0, max_frames - data2.shape[-1]))
data2 = np.pad(data2, pad_width, mode='constant')

elif time_align == 'prefix':
# Crop the longer sound's prefix
min_frames = min(data1.shape[-1], data2.shape[-1])
data1 = data1[:, :min_frames]
data2 = data2[:, :min_frames]

elif time_align == 'resample_max':
# Resample the shorter sound to match the longer one
max_frames = max(data1.shape[-1], data2.shape[-1])
if data1.shape[-1] < max_frames:
data1 = resample(data1, max_frames, axis=-1)
elif data2.shape[-1] < max_frames:
data2 = resample(data2, max_frames, axis=-1)

elif time_align == 'resample_min':
# Resample the longer sound to match the shorter one
min_frames = min(data1.shape[-1], data2.shape[-1])
if data1.shape[-1] > min_frames:
data1 = resample(data1, min_frames, axis=-1)
elif data2.shape[-1] > min_frames:
data2 = resample(data2, min_frames, axis=-1)

else:
raise ValueError(f'Invalid time align mode - {time_align!r}.')

# shape: (channels, frames)
return (data1, sr1), (data2, sr2)
6 changes: 3 additions & 3 deletions soundutils/similarity/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from .base import _align_sounds


def sound_correlation(
def sound_pearson_similarity(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['pad', 'resample', 'none'] = 'none',
time_align: Literal['none', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'none',
channels_align: Literal['none'] = 'none',
) -> float:
data1, data2 = _align_sounds(
(data1, sr1), (data2, sr2) = _align_sounds(
sound1=sound1,
sound2=sound2,
resample_rate_align=resample_rate_align,
Expand Down
27 changes: 0 additions & 27 deletions soundutils/similarity/cosine.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
from typing import Literal

import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

from .base import _align_sounds
from ..data import SoundTyping


def sound_euclidean(
def sound_fastdtw(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['pad', 'resample', 'none'] = 'none',
time_align: Literal['noncheck', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'noncheck',
channels_align: Literal['none'] = 'none',
radius: int = 1,
) -> float:
data1, data2 = _align_sounds(
(data1, sr1), (data2, sr2) = _align_sounds(
sound1=sound1,
sound2=sound2,
resample_rate_align=resample_rate_align,
time_align=time_align,
channels_align=channels_align,
)

# Euclidean distance
euclidean_distance = np.mean(
[euclidean(data1[i, :], data2[i, :]) for i in range(data1.shape[0])])
return euclidean_distance.item()
return fastdtw(
data1.T, data2.T,
radius=radius,
dist=euclidean,
)
41 changes: 41 additions & 0 deletions soundutils/similarity/mse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Literal

import numpy as np

from .base import _align_sounds
from ..data import SoundTyping


def sound_mse(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['none', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'none',
channels_align: Literal['none'] = 'none',
p: float = 2.0
) -> float:
(data1, sr1), (data2, sr2) = _align_sounds(
sound1=sound1,
sound2=sound2,
resample_rate_align=resample_rate_align,
time_align=time_align,
channels_align=channels_align,
)

return np.mean((data1 - data2) ** p).item()


def sound_rmse(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['none', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'none',
channels_align: Literal['none'] = 'none',
p: float = 2.0,
) -> float:
(data1, sr1), (data2, sr2) = _align_sounds(
sound1=sound1,
sound2=sound2,
resample_rate_align=resample_rate_align,
time_align=time_align,
channels_align=channels_align,
)
return (np.mean((data1 - data2) ** p) ** (1.0 / p)).item()
43 changes: 43 additions & 0 deletions soundutils/similarity/spectral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Literal, Optional

import numpy as np
from scipy.signal import spectrogram

from .base import _align_sounds
from ..data import SoundTyping


def sound_spectral_centroid_distance(
sound1: SoundTyping, sound2: SoundTyping,
resample_rate_align: Literal['max', 'min', 'none'] = 'none',
time_align: Literal['none', 'pad', 'prefix', 'resample_max', 'resample_min'] = 'none',
channels_align: Literal['none'] = 'none',
eps: Optional[float] = None
) -> float:
(data1, sr1), (data2, sr2) = _align_sounds(
sound1=sound1,
sound2=sound2,
resample_rate_align=resample_rate_align,
time_align=time_align,
channels_align=channels_align,
)

assert sr1 == sr2, 'Sample rate not match and not aligned, this must be a bug.'
sr = sr1

channels = data1.shape[0]
distances = []
eps = eps if eps is not None else np.finfo(data1.dtype).eps
for ch in range(channels):
_, _, Sxx1 = spectrogram(data1[ch], sr)
_, _, Sxx2 = spectrogram(data2[ch], sr)

Sxx1 += eps
Sxx2 += eps

centroid1 = np.sum(Sxx1 * np.arange(Sxx1.shape[0])[:, np.newaxis], axis=0) / np.sum(Sxx1, axis=0)
centroid2 = np.sum(Sxx2 * np.arange(Sxx2.shape[0])[:, np.newaxis], axis=0) / np.sum(Sxx2, axis=0)

distances.append(np.mean(np.abs(centroid1 - centroid2)))

return np.mean(distances).item()
25 changes: 0 additions & 25 deletions test/similarity/test_euclidean.py

This file was deleted.

Loading

0 comments on commit cabdc04

Please sign in to comment.