Skip to content

Commit

Permalink
refactor: handle audio message decoding in WebSocketStreamingServer and make WebSocketAudioSource a proxy class
Browse files Browse the repository at this point in the history
  • Loading branch information
janaab11 committed Jan 3, 2025
1 parent c6a2f92 commit 604d140
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 35 deletions.
32 changes: 0 additions & 32 deletions src/diart/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,38 +201,6 @@ def close(self):
self._mic_stream.close()


class WebSocketAudioSource(AudioSource):
    """Audio source fed by messages received over a WebSocket connection.

    The object does not open any connection itself: each raw message is
    handed to ``process_message``, which decodes it and emits the result
    on ``self.stream``.

    Parameters
    ----------
    uri: str
        Identifier forwarded to the ``AudioSource`` constructor.
    sample_rate: int
        Sample rate of the chunks emitted.
    """

    def __init__(
        self,
        uri: str,
        sample_rate: int,
    ):
        # FIXME sample_rate is never applied to the incoming audio, which can
        #  be confusing and lead to incompatibilities. It would be preferable
        #  for the client to send a JSON with data and sample rate, and to
        #  resample here if needed.
        super().__init__(uri, sample_rate)

    def process_message(self, message: AnyStr):
        """Decode one incoming audio message and emit it on the stream."""
        decoded_chunk = utils.decode_audio(message)
        # Hand the decoded audio over to whatever is subscribed to the stream
        self.stream.on_next(decoded_chunk)

    def read(self):
        """Do nothing: audio is pushed in through ``process_message``."""
        return None

    def close(self):
        """Signal completion of the audio stream for this client."""
        self.stream.on_completed()


class TorchStreamAudioSource(AudioSource):
def __init__(
self,
Expand Down
42 changes: 39 additions & 3 deletions src/diart/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from pathlib import Path
from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union

import numpy as np
from websocket_server import WebsocketServer

from . import utils
from . import blocks
from . import sources as src
from .inference import StreamingInference
Expand All @@ -17,11 +19,43 @@
logger = logging.getLogger(__name__)


class ProxyAudioSource(src.AudioSource):
    """Audio source that relays externally provided audio chunks.

    The object neither receives nor decodes network data itself: the caller
    passes each chunk to ``process_message``, which emits it unchanged on
    ``self.stream``.

    Parameters
    ----------
    uri: str
        Identifier forwarded to the ``AudioSource`` constructor.
    sample_rate: int
        Sample rate of the chunks emitted.
    """

    def __init__(
        self,
        uri: str,
        sample_rate: int,
    ):
        # FIXME sample_rate is never applied to the incoming audio, which can
        #  be confusing and lead to incompatibilities. It would be preferable
        #  for the client to send a JSON with data and sample rate, and to
        #  resample if needed.
        super().__init__(uri, sample_rate)

    def process_message(self, message: np.ndarray):
        """Emit one already-decoded audio chunk on the stream."""
        emit = self.stream.on_next
        # Hand the audio over to whatever is subscribed to the stream
        emit(message)

    def read(self):
        """Do nothing: audio is pushed in through ``process_message``."""
        return None

    def close(self):
        """Signal completion of the audio stream for this client."""
        self.stream.on_completed()


@dataclass
class ClientState:
    """Represents the state of a connected client."""

    # Per-client audio source; incoming audio is handed to its
    # ``process_message`` method.
    # NOTE: only the ``ProxyAudioSource`` annotation is kept. The earlier
    # ``src.WebSocketAudioSource`` annotation for this same field referred to
    # a class that no longer exists in ``sources`` and was overridden anyway.
    audio_source: ProxyAudioSource
    # Streaming inference object associated with this client.
    inference: StreamingInference


Expand Down Expand Up @@ -93,7 +127,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
# This ensures each client has its own state while sharing model weights
pipeline = self.pipeline_class(self.pipeline_config)

audio_source = src.WebSocketAudioSource(
audio_source = ProxyAudioSource(
uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
)

Expand Down Expand Up @@ -186,7 +220,9 @@ def _on_message_received(
return

try:
self._clients[client_id].audio_source.process_message(message)
# decode message to audio
decoded_audio = utils.decode_audio(message)
self._clients[client_id].audio_source.process_message(decoded_audio)
except (socket.error, ConnectionError) as e:
logger.warning(f"Client {client_id} disconnected: {e}")
# Just cleanup since client is already disconnected
Expand Down

0 comments on commit 604d140

Please sign in to comment.