Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STT service for hear node #38

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hri/packages/speech/config/hear.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
hear:
ros__parameters:
STT_SERVER_IP: "127.0.0.1:50051"
STT_SERVER_IP: "127.0.0.1:50051"
START_SERVICE: True
detection_publish_topic: "/keyword_detected"
69 changes: 62 additions & 7 deletions hri/packages/speech/scripts/hear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import rclpy
from rclpy.node import Node
from rclpy.executors import ExternalShutdownException
from rclpy.executors import ExternalShutdownException, MultiThreadedExecutor
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
import grpc
from frida_interfaces.msg import AudioData
from std_msgs.msg import String
from frida_interfaces.srv import STT
from std_msgs.msg import Bool, String
from speech.speech_api_utils import SpeechApiUtils

import sys
Expand Down Expand Up @@ -37,24 +39,57 @@ def __init__(self):
super().__init__("hear_node")
self.get_logger().info("*Starting Hear Node*")

# Get the gRPC server address from parameters
server_ip = (
self.declare_parameter("STT_SERVER_IP", "127.0.0.1:50051")
.get_parameter_value()
.string_value
)
start_service = (
self.declare_parameter("START_SERVICE", False)
.get_parameter_value()
.bool_value
)

# Initialize the Whisper gRPC client
self.client = WhisperClient(server_ip)

# Create a publisher for the transcriptions
self.transcription_publisher = self.create_publisher(
String, "/speech/raw_command", 10
)

# Create groups for the subscription and service
subscription_group = MutuallyExclusiveCallbackGroup()
service_group = MutuallyExclusiveCallbackGroup()

# Subscribe to audio data
self.audio_subscription = self.create_subscription(
AudioData, "UsefulAudio", self.callback_audio, 10
AudioData,
"UsefulAudio",
self.callback_audio,
10,
callback_group=subscription_group,
)

# Create a service
self.service_active = False
if start_service:
self.service_text = ""
detection_publish_topic = (
self.declare_parameter("detection_publish_topic", "/keyword_detected")
.get_parameter_value()
.string_value
)
self.KWS_publisher_mock = self.create_publisher(
Bool, detection_publish_topic, 10
)
self.stt_service = self.create_service(
STT,
"stt_service",
self.stt_service_callback,
callback_group=service_group,
)

self.get_logger().info("*Hear Node is ready*")

def callback_audio(self, data):
Expand All @@ -75,21 +110,41 @@ def callback_audio(self, data):
# Publish the transcription
msg = String()
msg.data = transcription
self.transcription_publisher.publish(msg)
self.get_logger().info("Transcription published to ROS topic.")

if self.service_active:
# If the service is active, store the transcription
self.service_text = transcription
self.service_active = False
else:
# If the service is not active, publish the transcription
self.transcription_publisher.publish(msg)
self.get_logger().info("Transcription published to ROS topic.")
except grpc.RpcError as e:
self.get_logger().error(f"gRPC error: {e.code()}, {e.details()}")
except Exception as ex:
self.get_logger().error(f"Error during transcription: {str(ex)}")

def stt_service_callback(self, request, response):
self.get_logger().info("Keyword mock service activated, recording audio...")
self.service_active = True
self.KWS_publisher_mock.publish(Bool(data=True))
while self.service_active:
pass
response.text_heard = self.service_text
return response


def main(args=None):
rclpy.init(args=args)
node = HearNode()
executor = MultiThreadedExecutor()
executor.add_node(node)
try:
rclpy.spin(HearNode())
executor.spin()
except (ExternalShutdownException, KeyboardInterrupt):
pass
finally:
node.destroy_node()
rclpy.shutdown()


Expand Down
18 changes: 7 additions & 11 deletions hri/packages/speech/scripts/useful_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def __init__(self):
self.timer = None
self.is_saying = False
self.audio_state = "None"
self.service_active = False

self.publisher = self.create_publisher(AudioData, "UsefulAudio", 20)
self.audio_state_publisher = self.create_publisher(String, "AudioState", 10)
Expand All @@ -89,7 +88,7 @@ def __init__(self):
AudioData, "rawAudioChunk", self.callback_raw_audio, 10
)
self.create_subscription(Bool, "saying", self.callback_saying, 10)
self.create_subscription(Bool, "keyword_detected", self.callback_keyword, 10)
self.create_subscription(Bool, "/keyword_detected", self.callback_keyword, 10)

if not self.use_silero_vad:
self.vad = webrtcvad.Vad()
Expand Down Expand Up @@ -126,10 +125,9 @@ def build_audio(self, data):
self.chunk_count += 1

def discard_audio(self):
if not self.service_active:
self.ring_buffer.clear()
self.voiced_frames = None
self.chunk_count = 0
self.ring_buffer.clear()
self.voiced_frames = None
self.chunk_count = 0

def publish_audio(self):
if self.chunk_count > MIN_CHUNKS_AUDIO_LENGTH:
Expand Down Expand Up @@ -207,8 +205,7 @@ def vad_collector(self, chunk):
):
self.triggered = False
self.compute_audio_state()
if not self.service_active:
self.publish_audio()
self.publish_audio()
self.timer = None

def callback_raw_audio(self, msg):
Expand All @@ -225,9 +222,8 @@ def callback_saying(self, msg):

def callback_keyword(self, msg):
self.triggered = True
if not self.service_active:
self.discard_audio()
self.compute_audio_state()
self.discard_audio()
GilMM27 marked this conversation as resolved.
Show resolved Hide resolved
self.compute_audio_state()

def compute_audio_state(self):
new_state = (
Expand Down
Loading