diff --git a/ovos_plugin_manager/templates/language.py b/ovos_plugin_manager/templates/language.py index 34efdc51..743674ff 100644 --- a/ovos_plugin_manager/templates/language.py +++ b/ovos_plugin_manager/templates/language.py @@ -1,75 +1,88 @@ +import abc + from ovos_config.config import Configuration from ovos_utils import classproperty from ovos_utils.process_utils import RuntimeRequirements +from typing import Optional, Dict, Union, List, Set class LanguageDetector: - def __init__(self, config=None): + def __init__(self, config: Optional[Dict[str, Union[str, int]]] = None): + """ + Initialize the LanguageDetector with configuration settings. + + Args: + config (Optional[Dict[str, Union[str, int]]]): Configuration dictionary. + Can contain "lang" for default language, "hint_lang" for a hint language, and "boost" for language boost score. + """ self.config = config or {} - self.default_language = self.config.get("lang") or "en-us" - # hint_language: str E.g., 'it' boosts Italian - self.hint_language = self.config.get("hint_lang") or \ - self.config.get('user') or self.default_language - # boost score for this language + self.default_language = self.config.get("lang", "en-us") + self.hint_language = self.config.get("hint_lang") or self.config.get('user') or self.default_language self.boost = self.config.get("boost") @classproperty - def runtime_requirements(self): - """ skill developers should override this if they do not require connectivity - some examples: - IOT plugin that controls devices via LAN could return: - scans_on_init = True - RuntimeRequirements(internet_before_load=False, - network_before_load=scans_on_init, - requires_internet=False, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=False) - online search plugin with a local cache: - has_cache = False - RuntimeRequirements(internet_before_load=not has_cache, - network_before_load=not has_cache, - requires_internet=True, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=True) - a fully offline plugin: - RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - """ - return RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - - def detect(self, text): - # assume default language - return self.default_language - - def detect_probs(self, text): - return {self.detect(text): 1} - - @property - def available_languages(self) -> set: + def runtime_requirements(self) -> RuntimeRequirements: + """ + Define the runtime requirements for this language detector. + + Returns: + RuntimeRequirements: Object indicating the runtime needs, including internet and network requirements. + """ + return RuntimeRequirements( + internet_before_load=False, + network_before_load=False, + requires_internet=False, + requires_network=False, + no_internet_fallback=True, + no_network_fallback=True + ) + + @abc.abstractmethod + def detect(self, text: str) -> str: + """ + Detect the language of the given text. + + Args: + text (str): The text to detect the language of. + + Returns: + str: The detected language code (e.g., 'en-us'). + """ + + @abc.abstractmethod + def detect_probs(self, text: str) -> Dict[str, float]: + """ + Detect the language of the text and return probabilities. + + Args: + text (str): The text to detect the language of. + + Returns: + Dict[str, float]: A dictionary with the detected language as the key and its probability as the value. + """ + + @property # TODO - make abstract method in future releases (mandatory for plugins to implement) + def available_languages(self) -> Set[str]: """ Return languages supported by this detector implementation in this state. This should be a set of languages this detector is capable of recognizing. This property should be overridden by the derived class to advertise what languages that engine supports. Returns: - set: supported languages + Set[str]: A set of language codes supported by this detector. """ return set() class LanguageTranslator: - def __init__(self, config=None): + def __init__(self, config: Optional[Dict[str, str]] = None): + """ + Initialize the LanguageTranslator with configuration settings. + + Args: + config (Optional[Dict[str, str]]): Configuration dictionary. + Can contain "lang" for the default language and "internal" for the internal language. + """ self.config = config or {} # translate from, unless specified/detected otherwise self.default_language = self.config.get("lang") or "en-us" @@ -79,44 +92,48 @@ def __init__(self, config=None): self.default_language @classproperty - def runtime_requirements(self): - """ skill developers should override this if they do not require connectivity - some examples: - IOT plugin that controls devices via LAN could return: - scans_on_init = True - RuntimeRequirements(internet_before_load=False, - network_before_load=scans_on_init, - requires_internet=False, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=False) - online search plugin with a local cache: - has_cache = False - RuntimeRequirements(internet_before_load=not has_cache, - network_before_load=not has_cache, - requires_internet=True, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=True) - a fully offline plugin: - RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - """ - return RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - - def translate(self, text, target=None, source=None): - return text - - def translate_dict(self, data, lang_tgt, lang_src="en"): + def runtime_requirements(self) -> RuntimeRequirements: + """ + Define the runtime requirements for this language translator. + + Returns: + RuntimeRequirements: Object indicating the runtime needs, including internet and network requirements. + """ + return RuntimeRequirements( + internet_before_load=False, + network_before_load=False, + requires_internet=False, + requires_network=False, + no_internet_fallback=True, + no_network_fallback=True + ) + + @abc.abstractmethod + def translate(self, text: str, target: Optional[str] = None, source: Optional[str] = None) -> str: + """ + Translate the given text from the source language to the target language. + + Args: + text (str): The text to translate. + target (Optional[str]): The target language code. If None, the internal language is used. + source (Optional[str]): The source language code. If None, the default language is used. + + Returns: + str: The translated text. + """ + + def translate_dict(self, data: Dict[str, Union[str, Dict, List]], lang_tgt: str, lang_src: str = "en") -> Dict[str, Union[str, Dict, List]]: + """ + Translate the values in a dictionary from one language to another. + + Args: + data (Dict[str, Union[str, Dict, List]]): The dictionary containing text to translate. + lang_tgt (str): The target language code. + lang_src (str): The source language code. + + Returns: + Dict[str, Union[str, Dict, List]]: The dictionary with translated values. + """ for k, v in data.items(): if isinstance(v, dict): data[k] = self.translate_dict(v, lang_tgt, lang_src) @@ -126,7 +143,18 @@ def translate_dict(self, data, lang_tgt, lang_src="en"): data[k] = self.translate_list(v, lang_tgt, lang_src) return data - def translate_list(self, data, lang_tgt, lang_src="en"): + def translate_list(self, data: List[Union[str, Dict, List]], lang_tgt: str, lang_src: str = "en") -> List[Union[str, Dict, List]]: + """ + Translate the values in a list from one language to another. + + Args: + data (List[Union[str, Dict, List]]): The list containing text to translate. + lang_tgt (str): The target language code. + lang_src (str): The source language code. + + Returns: + List[Union[str, Dict, List]]: The list with translated values. + """ for idx, v in enumerate(data): if isinstance(v, dict): data[idx] = self.translate_dict(v, lang_tgt, lang_src) @@ -136,25 +164,27 @@ def translate_list(self, data, lang_tgt, lang_src="en"): data[idx] = self.translate_list(v, lang_tgt, lang_src) return data - @property - def available_languages(self) -> set: + @property # TODO - make abstract method in future releases (mandatory for plugins to implement) + def available_languages(self) -> Set[str]: """ Return languages supported by this translator implementation in this state. Any language in this set should be translatable to any other language in the set. This property should be overridden by the derived class to advertise what languages that engine supports. Returns: - set: supported languages + Set[str]: A set of language codes supported by this translator. """ return set() - def supported_translations(self, source_lang: str = None) -> set: + # TODO - make abstract method in future releases (mandatory for plugins to implement) + def supported_translations(self, source_lang: Optional[str] = None) -> Set[str]: """ - Return valid target languages we can translate `source_lang` to. - This method should be overridden by the derived class. + Get the set of target languages to which the source language can be translated. + Args: - source_lang: ISO 639-1 source language code + source_lang (Optional[str]): The source language code. + Returns: - set of ISO 639-1 languages the source language can be translated to + Set[str]: A set of language codes that the source language can be translated to. """ return self.available_languages diff --git a/ovos_plugin_manager/templates/solvers.py b/ovos_plugin_manager/templates/solvers.py index 92522b9d..dfa8f481 100644 --- a/ovos_plugin_manager/templates/solvers.py +++ b/ovos_plugin_manager/templates/solvers.py @@ -2,161 +2,393 @@ # QuestionSolver Improvements and other solver classes are OVOS originals licensed under Apache 2.0 import abc -from typing import Optional, List, Iterable, Tuple +import inspect +from functools import wraps, lru_cache +from typing import Optional, List, Iterable, Tuple, Dict, Union from json_database import JsonStorageXDG -from ovos_plugin_manager.language import OVOSLangTranslationFactory -from ovos_utils.log import LOG +from ovos_utils import flatten_list +from ovos_utils.log import LOG, log_deprecation from ovos_utils.xdg_utils import xdg_cache_home from quebra_frases import sentence_tokenize +from ovos_plugin_manager.language import OVOSLangTranslationFactory, OVOSLangDetectionFactory +from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector + + +def auto_translate(translate_keys: List[str], translate_str_args=True): + """ Decorator to ensure all kwargs in 'translate_keys' are translated to self.default_lang. + data returned by the decorated function will be translated back to original language + NOTE: not meant to be used outside solver plugins""" + + def func_decorator(func): + + @wraps(func) + def func_wrapper(*args, **kwargs): + solver: AbstractSolver = args[0] + # check if translation is enabled + if not solver.enable_tx: + return func(*args, **kwargs) + + lang = kwargs.get("lang") + # check if translation can be skipped + if any([lang is None, + lang == solver.default_lang, + lang in solver.supported_langs]): + LOG.debug(f"skipping translation, 'lang': {lang} is supported by {func}") + return func(*args, **kwargs) + + # translate string arguments + if translate_str_args: + args = list(args) + for idx, arg in enumerate(args): + if isinstance(arg, str): + LOG.debug( + f"translating string argument with index: '{idx}' from {lang} to {solver.default_lang} for func: {func}") + args[idx] = _do_tx(solver, arg, + source_lang=lang, + target_lang=solver.default_lang) + + # translate input keys + for k in translate_keys: + v = kwargs.get(k) + if not v: + continue + kwargs[k] = _do_tx(solver, v, + source_lang=lang, + target_lang=solver.default_lang) + + out = func(*args, **kwargs) + + # reverse translate + return _do_tx(solver, out, + source_lang=solver.default_lang, + target_lang=lang) + + return func_wrapper + + return func_decorator + + +def auto_detect_lang(text_keys: List[str]): + """ Decorator to auto detect language if needed + NOTE: requires "lang" argument, not meant to be used outside solver plugins""" + + def func_decorator(func): + + @wraps(func) + def func_wrapper(*args, **kwargs): + solver: AbstractSolver = args[0] + + # detect language if needed + lang = kwargs.get("lang") + if lang is None: + LOG.debug(f"'lang' missing in kwargs for func: {func}") + for k in text_keys: + v = kwargs.get(k) + if isinstance(v, str): + lang = solver.detect_language(v) + LOG.debug(f"detected 'lang': {lang} in key: '{k}' for func: {func}") + break + else: + for idx, v in enumerate(args): + if isinstance(v, str) and len(v.split(" ")) > 1: + lang = solver.detect_language(v) + LOG.debug(f"detected 'lang': {lang} in argument '{idx}' for func: {func}") + + kwargs["lang"] = lang + return func(*args, **kwargs) + + return func_wrapper + + return func_decorator + + +def _deprecate_context2lang(): + """Decorator to deprecate the 'context' kwarg and replace it with 'lang'. + NOTE: can only be used in methods that accept "lang" as argument""" + + def func_decorator(func): + + @wraps(func) + def func_wrapper(*args, **kwargs): + + # Inspect the function signature to ensure it has both 'lang' and 'context' parameters + signature = inspect.signature(func) + params = signature.parameters + + if "context" in kwargs: + # NOTE: deprecate this at same time we + # standardize plugin namespaces to opm.XXX + log_deprecation("'context' kwarg has been deprecated, " + "please pass 'lang' as it's own kwarg instead", "0.1.0") + if "lang" in kwargs["context"] and "lang" not in kwargs: + kwargs["lang"] = kwargs["context"]["lang"] + + # ensure valid kwargs + if "lang" not in params and "lang" in kwargs: + kwargs.pop("lang") + if "context" not in params and "context" in kwargs: + kwargs.pop("context") + return func(*args, **kwargs) + + return func_wrapper + + return func_decorator + class AbstractSolver: - # these are defined by the plugin developer - priority = 50 - enable_tx = False - enable_cache = False - - def __init__(self, config=None, translator=None, *args, **kwargs): - if args or kwargs: - LOG.warning("solver plugins init signature changed, please update to accept config=None, translator=None. " - "an exception will be raised in next stable release") - for arg in args: - if isinstance(arg, str): - kwargs["name"] = arg - if isinstance(arg, int): - kwargs["priority"] = arg - if "priority" in kwargs: - self.priority = kwargs["priority"] - if "enable_tx" in kwargs: - self.enable_tx = kwargs["enable_tx"] - if "enable_cache" in kwargs: - self.enable_cache = kwargs["enable_cache"] + """Base class for solvers that perform various NLP tasks.""" + + def __init__(self, config=None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority=50, + enable_tx=False, + enable_cache=False, + internal_lang: Optional[str] = None, + *args, **kwargs): + self.priority = priority + self.enable_tx = enable_tx + self.enable_cache = enable_cache self.config = config or {} self.supported_langs = self.config.get("supported_langs") or [] - self.default_lang = self.config.get("lang", "en") + self.default_lang = internal_lang or self.config.get("lang", "en") if self.default_lang not in self.supported_langs: self.supported_langs.insert(0, self.default_lang) - self.translator = translator or OVOSLangTranslationFactory.create() + self._translator = translator or OVOSLangTranslationFactory.create() if self.enable_tx else None + self._detector = detector or OVOSLangDetectionFactory.create() if self.enable_tx else None + LOG.debug(f"{self.__class__.__name__} default language: {self.default_lang}") + + @property + def detector(self): + """ language detector, lazy init on first access""" + if not self._detector: + # if it's being used, there is no recovery, do not try: except: + self._detector = OVOSLangDetectionFactory.create() + return self._detector + + @detector.setter + def detector(self, val): + self._detector = val + + @property + def translator(self): + """ language translator, lazy init on first access""" + if not self._translator: + # if it's being used, there is no recovery, do not try: except: + self._translator = OVOSLangTranslationFactory.create() + return self._translator + + @translator.setter + def translator(self, val): + self._translator = val @staticmethod - def sentence_split(text: str, max_sentences: int=25) -> List[str]: - return sentence_tokenize(text)[:max_sentences] + def sentence_split(text: str, max_sentences: int = 25) -> List[str]: + """ + Split text into sentences. - def _get_user_lang(self, context: Optional[dict] = None, - lang: Optional[str] = None) -> str: - context = context or {} - lang = lang or context.get("lang") or self.default_lang - lang = lang.split("-")[0] - return lang + :param text: Input text. + :param max_sentences: Maximum number of sentences to return. + :return: List of sentences. + """ + try: + # sentence_tokenize occasionally has issues with \n for some reason + return flatten_list([sentence_tokenize(t) + for t in text.split("\n")])[:max_sentences] + except Exception as e: + LOG.exception(f"Error in sentence_split: {e}") + return [text] + + @lru_cache(maxsize=128) + def detect_language(self, text: str) -> str: + """ + Detect the language of the input text. + + :param text: Input text. + :return: Detected language code. + """ + return self.detector.detect(text) - def _tx_query(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None): - if not self.enable_tx: - return query, context, lang - context = context or {} - lang = user_lang = self._get_user_lang(context, lang) + @lru_cache(maxsize=128) + def translate(self, text: str, + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> str: + """ + Translate text from source_lang to target_lang. - # translate input to default lang - if user_lang not in self.supported_langs: - lang = self.default_lang - query = self.translator.translate(query, lang, user_lang) + :param text: Input text. + :param target_lang: Target language code. + :param source_lang: Source language code. + :return: Translated text. + """ + source_lang = source_lang or self.detect_language(text) + target_lang = target_lang or self.default_lang + if source_lang.split("-")[0] == target_lang.split("-")[0]: + return text # skip translation + return self.translator.translate(text, + target=target_lang, + source=source_lang) - context["lang"] = lang + def translate_list(self, data: List[str], + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> List[str]: + """ + Translate a list of strings from source_lang to target_lang. - return query, context, lang + :param data: List of strings. + :param target_lang: Target language code. + :param source_lang: Source language code. + :return: List of translated strings. + """ + return self.translator.translate_list(data, + lang_tgt=target_lang, + lang_src=source_lang) + + def translate_dict(self, data: Dict[str, str], + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> Dict[str, str]: + """ + Translate a dictionary of strings from source_lang to target_lang. + + :param data: Dictionary of strings. + :param target_lang: Target language code. + :param source_lang: Source language code. + :return: Dictionary of translated strings. + """ + return self.translator.translate_dict(data, + lang_tgt=target_lang, + lang_src=source_lang) def shutdown(self): - """ module specific shutdown method """ + """Module specific shutdown method.""" pass class QuestionSolver(AbstractSolver): - """free form unscontrained spoken question solver - handling automatic translation back and forth as needed""" - - def __init__(self, config=None, translator=None, *args, **kwargs): - super().__init__(config, translator, *args, **kwargs) + """ + A solver for free-form, unconstrained spoken questions that handles automatic translation as needed. + """ + + def __init__(self, config: Optional[Dict] = None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority: int = 50, + enable_tx: bool = False, + enable_cache: bool = False, + internal_lang: Optional[str] = None, + *args, **kwargs): + """ + Initialize the QuestionSolver. + + :param config: Optional configuration dictionary. + :param translator: Optional language translator. + :param detector: Optional language detector. + :param priority: Priority of the solver. + :param enable_tx: Flag to enable translation. + :param enable_cache: Flag to enable caching. + """ + super().__init__(config, translator, detector, priority, + enable_tx, enable_cache, internal_lang, + *args, **kwargs) name = kwargs.get("name") or self.__class__.__name__ if self.enable_cache: # cache contains raw data self.cache = JsonStorageXDG(name + "_data", xdg_folder=xdg_cache_home(), - subfolder="neon_solvers") + subfolder="ovos_solvers") # spoken cache contains dialogs self.spoken_cache = JsonStorageXDG(name, xdg_folder=xdg_cache_home(), - subfolder="neon_solvers") + subfolder="ovos_solvers") else: self.cache = self.spoken_cache = {} # plugin methods to override @abc.abstractmethod - def get_spoken_answer(self, query: str, - context: Optional[dict] = None) -> str: + def get_spoken_answer(self, query: str, lang: Optional[str] = None) -> str: """ - query assured to be in self.default_lang - return a single sentence text response + Obtain the spoken answer for a given query. + + :param query: The query text. + :param lang: Optional language code. + :return: The spoken answer as a text response. """ raise NotImplementedError - def stream_utterances(self, query: str, - context: Optional[dict] = None) -> Iterable[str]: - """streaming api, yields utterances as they become available - each utterance can be sent to TTS before we have a full answer - this is particularly helpful with LLMs""" - ans = self.get_spoken_answer(query, context) + @_deprecate_context2lang() + def stream_utterances(self, query: str, lang: Optional[str] = None) -> Iterable[str]: + """ + Stream utterances for the given query as they become available. + + :param query: The query text. + :param lang: Optional language code. + :return: An iterable of utterances. + """ + ans = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang) for utt in self.sentence_split(ans): yield utt - def get_data(self, query: str, - context: Optional[dict] = None) -> dict: + @_deprecate_context2lang() + def get_data(self, query: str, lang: Optional[str] = None) -> Optional[dict]: """ - query assured to be in self.default_lang - return a dict response + Retrieve data for the given query. + + :param query: The query text. + :param lang: Optional language code. + :return: A dictionary containing the answer. """ - return {"answer": self.get_spoken_answer(query, context)} + return {"answer": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang)} - def get_image(self, query: str, - context: Optional[dict] = None) -> str: + @_deprecate_context2lang() + def get_image(self, query: str, lang: Optional[str] = None) -> Optional[str]: """ - query assured to be in self.default_lang - return path/url to a single image to acompany spoken_answer + Get the path or URL to an image associated with the query. + + :param query: The query text + :param lang: Optional language code. + :return: The path or URL to a single image. """ return None - def get_expanded_answer(self, query: str, - context: Optional[dict] = None) -> List[dict]: + @_deprecate_context2lang() + def get_expanded_answer(self, query: str, lang: Optional[str] = None) -> List[dict]: """ - query assured to be in self.default_lang - return a list of ordered steps to expand the answer, eg, "tell me more" - { - "title": "optional", - "summary": "speak this", - "img": "optional/path/or/url - } - :return: + Get an expanded list of steps to elaborate on the answer. + + :param query: The query text + :param lang: Optional language code. + :return: A list of dictionaries with each step containing a title, summary, and optional image. """ return [{"title": query, - "summary": self.get_spoken_answer(query, context), - "img": self.get_image(query, context)}] + "summary": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang), + "img": _call_with_sanitized_kwargs(self.get_image, query, lang=lang)}] # user facing methods - def search(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> dict: + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def search(self, query: str, lang: Optional[str] = None) -> dict: """ - cache and auto translate query if needed - returns translated response from self.get_data + Perform a search with automatic translation and caching. + + NOTE: "lang" assured to be in self.supported_langs, + otherwise "query" automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back + + :param query: The query text. + :param lang: Optional language code. + :return: The data dictionary retrieved from the cache or computed anew. """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) # read from cache if self.enable_cache and query in self.cache: data = self.cache[query] else: # search data try: - data = self.get_data(query, context) + data = _call_with_sanitized_kwargs(self.get_data, query, lang=lang) except: return {} @@ -164,215 +396,352 @@ def search(self, query: str, if self.enable_cache: self.cache[query] = data self.cache.store() - - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate_dict(data, user_lang, lang) return data - def visual_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def visual_answer(self, query: str, lang: Optional[str] = None) -> str: """ - cache and auto translate query if needed - returns image that answers query - """ - query, context, lang = self._tx_query(query, context, lang) - return self.get_image(query, context) + Retrieve the image associated with the query with automatic translation and caching. + + NOTE: "lang" assured to be in self.supported_langs, + otherwise "query" automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back - def spoken_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + :param query: The query text. + :param lang: Optional language code. + :return: The path or URL to the image. """ - cache and auto translate query if needed - returns chunked and translated response from self.get_spoken_answer + return _call_with_sanitized_kwargs(self.get_image, query, lang=lang) + + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def spoken_answer(self, query: str, lang: Optional[str] = None) -> str: """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) + Retrieve the spoken answer for the query with automatic translation and caching. + NOTE: "lang" assured to be in self.supported_langs, + otherwise "query" automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back + + :param query: The query text. + :param lang: Optional language code. + :return: The spoken answer as a text response. + """ # get answer if self.enable_cache and query in self.spoken_cache: # read from cache summary = self.spoken_cache[query] else: - summary = self.get_spoken_answer(query, context) + + summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang) # save to cache if self.enable_cache: self.spoken_cache[query] = summary self.spoken_cache.store() + return summary - # summarize - if summary: - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(summary, user_lang, lang) - else: - return summary - - def long_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> List[dict]: - """ - return a list of ordered steps to expand the answer, eg, "tell me more" - step0 is always self.spoken_answer and self.get_image - { - "title": "optional", - "summary": "speak this", - "img": "optional/path/or/url - } - :return: + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def long_answer(self, query: str, lang: Optional[str] = None) -> List[dict]: """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) - steps = self.get_expanded_answer(query, context) + Retrieve a detailed list of steps to expand the answer. + NOTE: "lang" assured to be in self.supported_langs, + otherwise "query" automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back + + :param query: The query text. + :param lang: Optional language code. + :return: A list of steps to elaborate on the answer, with each step containing a title, summary, and optional image. + """ + steps = _call_with_sanitized_kwargs(self.get_expanded_answer, query, lang=lang) # use spoken_answer as last resort if not steps: - summary = self.get_spoken_answer(query, context) + summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang) if summary: - img = self.get_image(query, context) + img = _call_with_sanitized_kwargs(self.get_image, query, lang=lang) steps = [{"title": query, "summary": step0, "img": img} for step0 in self.sentence_split(summary, -1)] - - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate_list(steps, user_lang, lang) return steps -class TldrSolver(AbstractSolver): - """perform NLP summarization task, - handling automatic translation back and forth as needed""" +class CorpusSolver(QuestionSolver): + """Retrieval based question solver""" - # plugin methods to override + def __init__(self, config=None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority: int = 50, + enable_tx: bool = False, + enable_cache: bool = False, + *args, **kwargs): + super().__init__(config, translator, detector, + priority, enable_tx, enable_cache, + *args, **kwargs) + LOG.debug(f"corpus presumed to be in language: {self.default_lang}") + + @abc.abstractmethod + def load_corpus(self, corpus: List[str]): + """index the provided list of sentences""" + + @abc.abstractmethod + def query(self, query: str, lang: Optional[str], k: int = 3) -> Iterable[Tuple[str, float]]: + """return top_k matches from indexed corpus""" + + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def retrieve_from_corpus(self, query: str, k: int = 3, lang: Optional[str] = None) -> List[Tuple[float, str]]: + """return top_k matches from indexed corpus""" + res = [] + for doc, score in self.query(query, lang, k=k): + LOG.debug(f"Rank {len(res) + 1} (score: {score}): {doc}") + if self.config.get("min_conf"): + if score >= self.config["min_conf"]: + res.append((score, doc)) + else: + res.append((score, doc)) + return res + + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def get_spoken_answer(self, query: str, lang: Optional[str] = None) -> str: + # Query the corpus + answers = [a[1] for a in self.retrieve_from_corpus(query, lang=lang, + k=self.config.get("n_answer", 1))] + if answers: + return ". ".join(answers[:self.config.get("n_answer", 1)]) + + +class QACorpusSolver(CorpusSolver): + def __init__(self, config=None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority: int = 50, + enable_tx: bool = False, + enable_cache: bool = False, + *args, **kwargs): + self.answers = {} + super().__init__(config, translator, detector, + priority, enable_tx, enable_cache, + *args, **kwargs) + + def load_corpus(self, corpus: Dict): + self.answers = corpus + super().load_corpus(list(self.answers.keys())) + + @auto_detect_lang(text_keys=["query"]) + @auto_translate(translate_keys=["query"]) + def retrieve_from_corpus(self, query: str, k: int = 1, lang: Optional[str] = None) -> List[Tuple[float, str]]: + res = [] + for doc, score in super().retrieve_from_corpus(query, k, lang): + LOG.debug(f"Answer {len(res) + 1} (score: {score}): {self.answers[doc]}") + res.append((score, self.answers[doc])) + return res + + +class TldrSolver(AbstractSolver): + """ + Solver for performing NLP summarization tasks, + handling automatic translation as needed. + """ @abc.abstractmethod def get_tldr(self, document: str, - context: Optional[dict] = None) -> str: + lang: Optional[str] = None) -> str: """ - document assured to be in self.default_lang - returns summary of provided document + Summarize the provided document. + + :param document: The text of the document to summarize, assured to be in the default language. + :param lang: Optional language code. + :return: A summary of the provided document. """ raise NotImplementedError # user facing methods - def tldr(self, document: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: - """ - cache and auto translate query if needed - returns summary of provided document + + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["document"]) + @auto_translate(translate_keys=["document"]) + def tldr(self, document: str, lang: Optional[str] = None) -> str: """ - user_lang = self._get_user_lang(context, lang) - document, context, lang = self._tx_query(document, context, lang) + Summarize the provided document with automatic translation and caching if needed. - # summarize - tldr = self.get_tldr(document, context) + NOTE: "lang" assured to be in self.supported_langs, + otherwise "document" automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back - # translate output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(tldr, user_lang, lang) - return tldr + :param document: The text of the document to summarize. + :param lang: Optional language code. + :return: A summary of the provided document. + """ + # summarize + return _call_with_sanitized_kwargs(self.get_tldr, document, lang=lang) class EvidenceSolver(AbstractSolver): - """perform NLP reading comprehension task, - handling automatic translation back and forth as needed""" - - # plugin methods to override + """ + Solver for NLP reading comprehension tasks, + handling automatic translation as needed. + """ @abc.abstractmethod def get_best_passage(self, evidence: str, question: str, - context: Optional[dict] = None) -> str: + lang: Optional[str] = None) -> str: """ - evidence and question assured to be in self.default_lang - returns summary of provided document + Extract the best passage from evidence that answers the given question. + + :param evidence: The text containing the evidence, assured to be in the default language. + :param question: The question to answer, assured to be in the default language. + :param lang: Optional language code. + :return: The passage from the evidence that best answers the question. """ raise NotImplementedError # user facing methods + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["evidence", "question"]) + @auto_translate(translate_keys=["evidence", "question"]) def extract_answer(self, evidence: str, question: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: - """ - cache and auto translate evidence and question if needed - returns passage from evidence that answers question + lang: Optional[str] = None) -> str: """ - user_lang = self._get_user_lang(context, lang) - evidence, context, lang = self._tx_query(evidence, context, lang) - question, context, lang = self._tx_query(question, context, lang) + Extract the best passage from evidence that answers the question with automatic translation and caching if needed. - # extract answer from doc - ans = self.get_best_passage(evidence, question, context) + NOTE: "lang" assured to be in self.supported_langs, + otherwise "evidence" and "question" are automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back - # translate output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(ans, user_lang, lang) - return ans + :param evidence: The text containing the evidence. + :param question: The question to answer. + :param lang: Optional language code. + :return: The passage from the evidence that answers the question. + """ + # extract answer from doc + return self.get_best_passage(evidence, question, lang=lang) class MultipleChoiceSolver(AbstractSolver): - """ select best answer from question + multiple choice - handling automatic translation back and forth as needed""" - - # plugin methods to override + """ + Solver for selecting the best answer from a question with multiple choices, + handling automatic translation as needed. + """ - # TODO - make abstract in the future, - # just giving some time buffer to update existing - # plugins in the wild missing this method - #@abc.abstractmethod + @abc.abstractmethod def rerank(self, query: str, options: List[str], - context: Optional[dict] = None) -> List[Tuple[float, str]]: + lang: Optional[str] = None, + return_index: bool = False) -> List[Tuple[float, Union[str, int]]]: """ - rank options list, returning a list of tuples (score, text) + Rank the provided options based on the query. + + :param query: The query text, assured to be in the default language. + :param options: A list of answer options, each assured to be in the default language. + :param lang: Optional language code. + :param return_index: If True, return the index of the best option; otherwise, return the best option text. + :return: A list of tuples where each tuple contains a score and the corresponding option text, sorted by score. """ raise NotImplementedError + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["query", "options"]) + @auto_translate(translate_keys=["query", "options"]) def select_answer(self, query: str, options: List[str], - context: Optional[dict] = None) -> str: + lang: Optional[str] = None, + return_index: bool = False) -> Union[str, int]: """ - query and options assured to be in self.default_lang - return best answer from options list - """ - return self.rerank(query, options, context)[0][1] + Select the best answer from the provided options based on the query with automatic translation and caching if needed. - # user facing methods - def solve(self, query: str, options: List[str], - context: Optional[dict] = None, lang: Optional[str] = None) -> str: - """ - cache and auto translate query and options if needed - returns best answer from provided options - """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) - opts = [self.translator.translate(opt, lang, user_lang) - for opt in options] + NOTE: "lang" assured to be in self.supported_langs, + otherwise "query" and "options" are automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back - # select best answer - ans = self.select_answer(query, opts, context) - - idx = opts.index(ans) - return options[idx] + :param query: The query text. + :param options: A list of answer options. + :param lang: Optional language code. + :param return_index: If True, return the index of the best option; otherwise, return the best option text. + :return: The best answer from the options list, or the index of the best option if `return_index` is True. + """ + return self.rerank(query, options, lang=lang, return_index=return_index)[0][1] class EntailmentSolver(AbstractSolver): """ select best answer from question + multiple choice handling automatic translation back and forth as needed""" - # plugin methods to override - @abc.abstractmethod def check_entailment(self, premise: str, hypothesis: str, - context: Optional[dict] = None) -> bool: + lang: Optional[str] = None) -> bool: """ - premise and hyopithesis assured to be in self.default_lang - return Bool, True if premise entails the hypothesis False otherwise + Check if the premise entails the hypothesis. + + :param premise: The premise text, assured to be in the default language. + :param hypothesis: The hypothesis text, assured to be in the default language. + :param lang: Optional language code. + :return: True if the premise entails the hypothesis; False otherwise. """ raise NotImplementedError # user facing methods - def entails(self, premise: str, hypothesis: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> bool: + @_deprecate_context2lang() + @auto_detect_lang(text_keys=["premise", "hypothesis"]) + @auto_translate(translate_keys=["premise", "hypothesis"]) + def entails(self, premise: str, hypothesis: str, lang: Optional[str] = None) -> bool: """ - cache and auto translate premise and hypothesis if needed - return Bool, True if premise entails the hypothesis False otherwise + Determine if the premise entails the hypothesis with automatic translation and caching if needed. + + NOTE: "lang" assured to be in self.supported_langs, + otherwise "premise" and "hypothesis" are automatically translated to self.default_lang. + If translations happens, the returned value of this method will also + be automatically translated back + + :param premise: The premise text. + :param hypothesis: The hypothesis text. + :param lang: Optional language code. + :return: True if the premise entails the hypothesis; False otherwise. """ - premise, context, lang = self._tx_query(premise, context, lang) - hypothesis, context, lang = self._tx_query(hypothesis, context, lang) # check for entailment - return self.check_entailment(premise, hypothesis) + return self.check_entailment(premise, hypothesis, lang=lang) + + +def _do_tx(solver, data, source_lang, target_lang): + if isinstance(data, str): + return solver.translate(data, + source_lang=source_lang, target_lang=target_lang) + elif isinstance(data, list): + for idx, e in enumerate(data): + data[idx] = _do_tx(solver, e, source_lang=source_lang, target_lang=target_lang) + elif isinstance(data, dict): + for k, v in data.items(): + data[k] = _do_tx(solver, v, source_lang=source_lang, target_lang=target_lang) + elif isinstance(data, tuple) and len(data) == 2: + if isinstance(data[0], str): + a = _do_tx(solver, data[0], source_lang=source_lang, target_lang=target_lang) + else: + a = data[0] + if isinstance(data[1], str): + b = _do_tx(solver, data[1], source_lang=source_lang, target_lang=target_lang) + else: + b = data[1] + return (a, b) + return data + + +def _call_with_sanitized_kwargs(func, *args, lang: Optional[str] = None): + # Inspect the function signature to ensure it has both 'lang' and 'context' parameters + params = inspect.signature(func).parameters + kwargs = {} + if "lang" in params: + # new style - only lang is passed + kwargs["lang"] = lang + elif "context" in kwargs: + # old style - when plugins received context only + kwargs["context"]["lang"] = lang + return func(*args, **kwargs) diff --git a/requirements/test.txt b/requirements/test.txt index 7694cc16..4373cf13 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,4 +2,5 @@ pytest pytest-timeout pytest-cov ovos-translate-server-plugin +ovos-classifiers ovos-utils>=0.1.0a8 \ No newline at end of file diff --git a/test/unittests/test_solver.py b/test/unittests/test_solver.py index c460f267..354b18af 100644 --- a/test/unittests/test_solver.py +++ b/test/unittests/test_solver.py @@ -1,8 +1,10 @@ import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock +from ovos_plugin_manager.templates.solvers import QuestionSolver, auto_detect_lang, auto_translate, _deprecate_context2lang, AbstractSolver from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes -from ovos_plugin_manager.templates.solvers import QuestionSolver + + # TODO: Test Tldr, Evidence, MultipleChoice, Entailment @@ -141,11 +143,11 @@ def test_translation(self): solver.translator.translate.return_value = "a wild translation appears" # no translation - ans = solver.spoken_answer("some query") + ans = solver.spoken_answer("some query", lang="en") solver.translator.translate.assert_not_called() # translation - ans = solver.spoken_answer("not english", context={"lang": "unk"}) + ans = solver.spoken_answer("not english", lang="unk") solver.translator.translate.assert_called() @@ -398,3 +400,76 @@ def test_get_supported_langs(self, get_supported_languages): get_reading_comprehension_solver_supported_langs get_reading_comprehension_solver_supported_langs() get_supported_languages.assert_called_once_with(self.PLUGIN_TYPE) + + +class TestAutoTranslate(unittest.TestCase): + def setUp(self): + self.solver = AbstractSolver(enable_tx=True, default_lang='en') + self.solver.translate = MagicMock(side_effect=lambda text, source_lang=None, target_lang=None: text[ + ::-1] if source_lang and target_lang else text) + + def test_auto_translate_decorator(self): + @auto_translate(translate_keys=['text']) + def test_func(solver, text, lang=None): + return text[::-1] + + result = test_func(self.solver, 'hello', lang='es') + self.assertEqual(result, 'olleh') # 'hello' reversed due to mock translation + + def test_auto_translate_no_translation(self): + @auto_translate(translate_keys=['text']) + def test_func(solver, text, lang=None): + return text + + result = test_func(self.solver, 'hello') + self.assertEqual(result, 'hello') + + +class TestAutoDetectLang(unittest.TestCase): + def setUp(self): + self.solver = AbstractSolver() + self.solver.detect_language = MagicMock(return_value='en') + + def test_auto_detect_lang_decorator(self): + self.solver.detector = Mock() + self.solver.detector.detect.return_value = "en" + + @auto_detect_lang(text_keys=['text']) + def test_func(solver, text, lang=None): + return lang + + result = test_func(self.solver, 'hello world') + self.assertEqual(result, 'en') + + def test_auto_detect_lang_with_lang(self): + @auto_detect_lang(text_keys=['text']) + def test_func(solver, text, lang=None): + return lang + + result = test_func(self.solver, 'hello', lang='es') + self.assertEqual(result, 'es') + + +class TestDeprecateContext2Lang(unittest.TestCase): + def setUp(self): + self.solver = AbstractSolver() + + def test_deprecate_context2lang(self): + @_deprecate_context2lang() + def test_func(solver, lang=None): + return lang + + result = test_func(self.solver, context={'lang': 'en'}) + self.assertEqual(result, 'en') + + def test_no_context(self): + @_deprecate_context2lang() + def test_func(solver, lang=None): + return lang + + result = test_func(self.solver, lang='fr') + self.assertEqual(result, 'fr') + + +if __name__ == '__main__': + unittest.main()