Skip to content

Commit

Permalink
asyncio: Add background_task decorator
Browse files Browse the repository at this point in the history
This should make calling async functions from sync ones more natural.
  • Loading branch information
rafaelmardojai committed Sep 8, 2024
1 parent a4442d0 commit 7071f69
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 78 deletions.
28 changes: 23 additions & 5 deletions dialect/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import contextlib
from typing import Coroutine
import functools
from typing import Callable, Coroutine

from gi.events import GLibEventLoopPolicy
from gi.repository import GLib


@contextlib.contextmanager
Expand All @@ -21,12 +21,30 @@ def glib_event_loop_policy():


def create_background_task(coro: Coroutine) -> asyncio.Task:
"""Create and track a task.
"""
Create and track a task.
Normally tasks are weak-referenced by asyncio.
We keep track of them, so they can be completed
before GC kicks in.
We keep track of them, so they can be completed before GC kicks in.
"""
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return task


def background_task(f: Callable[..., Coroutine]):
"""
Wraps an async function to be run using ``create_background_task``.
Useful to use async functions like signal handlers or GTK template callbacks.
Note: The return value will be lost, so this is not suitable when you need to
return something from the coroutine, what might be needed in some signal handlers.
"""

@functools.wraps(f)
def decor(*args, **kwargs):
create_background_task(f(*args, **kwargs))

return decor
134 changes: 61 additions & 73 deletions dialect/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from gi.repository import Adw, Gdk, Gio, GLib, GObject, Gst, Gtk

from dialect.asyncio import create_background_task
from dialect.asyncio import background_task
from dialect.define import APP_ID, PROFILE, RES_PATH, TRANS_NUMBER
from dialect.languages import LanguagesListModel
from dialect.providers import (
Expand Down Expand Up @@ -227,9 +227,9 @@ def setup(self):
self.set_help_overlay(DialectShortcutsWindow())

# Load translator
create_background_task(self.load_translator())
self.load_translator()
# Load text to speech
create_background_task(self.load_tts())
self.load_tts()

# Listen to active providers changes
Settings.get().connect("provider-changed", self._on_active_provider_changed)
Expand Down Expand Up @@ -287,10 +287,11 @@ def setup_translation(self):
def reload_provider(self, kind: str):
match kind:
case "translator":
create_background_task(self.load_translator())
self.load_translator()
case "tts":
create_background_task(self.load_tts())
self.load_tts()

@background_task
async def load_translator(self):
self.translator_loading = True

Expand Down Expand Up @@ -441,6 +442,7 @@ def show_translator_api_key_view(self, required=False):

self.main_stack.props.visible_child_name = "api-key"

@background_task
async def load_tts(self):
self.src_speech_btn.loading()
self.dest_speech_btn.loading()
Expand Down Expand Up @@ -760,10 +762,8 @@ def _on_suggest_action(self, *_args):
)
self.dest_text.props.editable = True

def _on_suggest_ok_action(self, *_args):
create_background_task(self._send_suggestion())

async def _send_suggestion(self):
@background_task
async def _on_suggest_ok_action(self, *_args):
if not self.provider["trans"]:
return

Expand Down Expand Up @@ -802,7 +802,7 @@ def _on_src_listen_action(self, *_args):

src_text = self.src_buffer.get_text(self.src_buffer.get_start_iter(), self.src_buffer.get_end_iter(), True)
src_language = self.src_lang_selector.selected
self._pre_speech(src_text, src_language, "src")
self._on_speech(src_text, src_language, "src")

def _on_dest_listen_action(self, *_args):
if self.current_speech:
Expand All @@ -811,47 +811,36 @@ def _on_dest_listen_action(self, *_args):

dest_text = self.dest_buffer.get_text(self.dest_buffer.get_start_iter(), self.dest_buffer.get_end_iter(), True)
dest_language = self.dest_lang_selector.selected
self._pre_speech(dest_text, dest_language, "dest")
self._on_speech(dest_text, dest_language, "dest")

def _pre_speech(self, text: str, lang: str, called_from: Literal["src", "dest"]):
@background_task
async def _on_speech(self, text: str, lang: str, called_from: Literal["src", "dest"]):
# Retry loading TTS provider
if self.speech_provider_failed:
create_background_task(self.load_tts())
self.load_tts()
return

if text != "":
self.speech_loading = True
self.current_speech = {"text": text, "lang": lang, "called_from": called_from}
self._check_speech_enabled()
create_background_task(self._download_speech())

if called_from == "src": # Show spinner on button
self.src_speech_btn.loading()
else:
self.dest_speech_btn.loading()

def _speech_reset(self, set_ready: bool = True):
if not self.player:
if not text or not self.provider["tts"] or not self.player:
return

self.player.set_state(Gst.State.NULL)
self.current_speech = None
self.speech_loading = False
# Set loading state and current speech to update UI
self.speech_loading = True
self.current_speech = {"text": text, "lang": lang, "called_from": called_from}
self._check_speech_enabled()

if set_ready:
self.src_speech_btn.ready()
self.dest_speech_btn.ready()

async def _download_speech(self):
if not self.provider["tts"] or not self.current_speech:
return
if called_from == "src": # Show spinner on button
self.src_speech_btn.loading()
else:
self.dest_speech_btn.loading()

# Download speech
try:
speech_file = await self.provider["tts"].speech(self.current_speech["text"], self.current_speech["lang"])

self._play_audio(speech_file.name)
speech_file.close()
file_ = await self.provider["tts"].speech(self.current_speech["text"], self.current_speech["lang"])
uri = "file://" + file_.name
self.player.set_property("uri", uri)
self.player.set_state(Gst.State.PLAYING)
self.add_tick_callback(self._gst_progress_timeout)
file_.close()

except (RequestError, ProviderError) as exc:
logging.error(exc)
Expand All @@ -878,14 +867,18 @@ async def _download_speech(self):
self.send_notification(text, action=action)
self._speech_reset(False)

def _play_audio(self, path: str):
def _speech_reset(self, set_ready: bool = True):
if not self.player:
return

uri = "file://" + path
self.player.set_property("uri", uri)
self.player.set_state(Gst.State.PLAYING)
self.add_tick_callback(self._gst_progress_timeout)
self.player.set_state(Gst.State.NULL)
self.current_speech = None
self.speech_loading = False
self._check_speech_enabled()

if set_ready:
self.src_speech_btn.ready()
self.dest_speech_btn.ready()

def _on_gst_message(self, _bus, message: Gst.Message):
if message.type == Gst.MessageType.EOS or message.type == Gst.MessageType.ERROR:
Expand Down Expand Up @@ -1077,36 +1070,13 @@ def _on_mistakes_clicked(self, *_args):
return Gdk.EVENT_STOP

@Gtk.Template.Callback()
def _on_translation(self, *_args):
if not self.provider["trans"]:
return

# If it's like the last translation then it's useless to continue
if not self._appeared_before():
create_background_task(self._translation())

def _appeared_before(self):
if not self.provider["trans"]:
return

src_language = self.src_lang_selector.selected
dest_language = self.dest_lang_selector.selected
src_text = self.src_buffer.get_text(self.src_buffer.get_start_iter(), self.src_buffer.get_end_iter(), True)
translation = self.current_translation
if (
len(self.provider["trans"].history) >= self.current_history + 1
and translation
and (translation.original.src == src_language or "auto")
and translation.original.dest == dest_language
and translation.original.text == src_text
):
return True
return False

async def _translation(self):
if not self.provider["trans"]:
@background_task
async def _on_translation(self, *_args):
if not self.provider["trans"] or self._appeared_before():
# If it's like the last translation then it's useless to continue
return

# Run translation
if self.next_translation:
request = self.next_translation
self.next_translation = None
Expand Down Expand Up @@ -1203,6 +1173,24 @@ async def _translation(self):
if not self.translation_loading:
self._translation_finish()

def _appeared_before(self):
if not self.provider["trans"]:
return

src_language = self.src_lang_selector.selected
dest_language = self.dest_lang_selector.selected
src_text = self.src_buffer.get_text(self.src_buffer.get_start_iter(), self.src_buffer.get_end_iter(), True)
translation = self.current_translation
if (
len(self.provider["trans"].history) >= self.current_history + 1
and translation
and (translation.original.src == src_language or "auto")
and translation.original.dest == dest_language
and translation.original.text == src_text
):
return True
return False

def _translation_finish(self):
self.trans_spinner.hide()
self.dest_box.props.sensitive = True
Expand Down

0 comments on commit 7071f69

Please sign in to comment.