Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Oct 9, 2024
2 parents bb36be4 + 8cf181f commit ead62fe
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))


- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))


## [1.4.2] - 2022-09-12

### Added
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -83,6 +85,11 @@ def perceptual_evaluation_speech_quality(
)
import pesq as pesq_backend

def _issubtype_number(x: Any) -> bool:
return np.issubdtype(type(x), np.number)

_filter_error_msg = np.vectorize(_issubtype_number)

if fs not in (8000, 16000):
raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
if mode not in ("wb", "nb"):
Expand All @@ -103,8 +110,8 @@ def perceptual_evaluation_speech_quality(
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])
pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32))
pesq_val = pesq_val.reshape(len(pesq_val))

if keep_same_device:
return pesq_val.to(preds.device)
Expand Down

0 comments on commit ead62fe

Please sign in to comment.