Skip to content

Commit

Permalink
Merge pull request #18 from raldone01/feat/support_custom_endpoints
Browse files Browse the repository at this point in the history
Support custom openai compatible endpoints
  • Loading branch information
sfortis authored Jul 17, 2024
2 parents d2e2c9f + 0460cd2 commit 537b576
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 43 deletions.
33 changes: 16 additions & 17 deletions custom_components/openai_tts/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,24 @@
from typing import Any
import voluptuous as vol
import logging
from urllib.parse import urlparse

from homeassistant import data_entry_flow
from homeassistant.config_entries import ConfigFlow
from homeassistant.helpers.selector import selector
from homeassistant.exceptions import HomeAssistantError

from .const import CONF_API_KEY, CONF_MODEL, CONF_VOICE, CONF_SPEED, DOMAIN, MODELS, VOICES
from .const import CONF_API_KEY, CONF_MODEL, CONF_VOICE, CONF_SPEED, CONF_URL, DOMAIN, MODELS, VOICES, UNIQUE_ID

_LOGGER = logging.getLogger(__name__)

class WrongAPIKey(HomeAssistantError):
"""Error to indicate no or wrong API key."""

async def validate_api_key(api_key: str):
"""Validate the API key format."""
if api_key is None:
raise WrongAPIKey("API key is required")
if not (51 <= len(api_key) <= 70):
raise WrongAPIKey("Invalid API key length")
def generate_unique_id(user_input: dict) -> str:
"""Generate a unique id from user input."""
url = urlparse(user_input[CONF_URL])
return f"{url.hostname}_{user_input[CONF_MODEL]}_{user_input[CONF_VOICE]}"

async def validate_user_input(user_input: dict):
"""Validate user input fields."""
await validate_api_key(user_input.get(CONF_API_KEY))
if user_input.get(CONF_MODEL) is None:
raise ValueError("Model is required")
if user_input.get(CONF_VOICE) is None:
Expand All @@ -35,22 +30,23 @@ class OpenAITTSConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OpenAI TTS."""
VERSION = 1
data_schema = vol.Schema({
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_API_KEY): str,
vol.Optional(CONF_URL, default="https://api.openai.com/v1/audio/speech"): str,
vol.Optional(CONF_SPEED, default=1.0): vol.Coerce(float),
vol.Required(CONF_MODEL, default="tts-1"): selector({
"select": {
"options": MODELS,
"mode": "dropdown",
"sort": True,
"custom_value": False
"custom_value": True
}
}),
vol.Required(CONF_VOICE, default="shimmer"): selector({
"select": {
"options": VOICES,
"mode": "dropdown",
"sort": True,
"custom_value": False
"custom_value": True
}
})
})
Expand All @@ -61,14 +57,17 @@ async def async_step_user(self, user_input: dict[str, Any] | None = None):
if user_input is not None:
try:
await validate_user_input(user_input)
await self.async_set_unique_id(f"{user_input[CONF_VOICE]}_{user_input[CONF_MODEL]}")
unique_id = generate_unique_id(user_input)
user_input[UNIQUE_ID] = unique_id
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()
return self.async_create_entry(title="OpenAI TTS", data=user_input)
hostname = urlparse(user_input[CONF_URL]).hostname
return self.async_create_entry(title=f"OpenAI TTS ({hostname}, {user_input[CONF_MODEL]}, {user_input[CONF_VOICE]})", data=user_input)
except data_entry_flow.AbortFlow:
return self.async_abort(reason="already_configured")
except HomeAssistantError as e:
_LOGGER.exception(str(e))
errors["api_key"] = "wrong_api_key"
errors["base"] = str(e)
except ValueError as e:
_LOGGER.exception(str(e))
errors["base"] = str(e)
Expand Down
3 changes: 2 additions & 1 deletion custom_components/openai_tts/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
CONF_MODEL = 'model'
CONF_VOICE = 'voice'
CONF_SPEED = 'speed'
CONF_URL = 'url'
UNIQUE_ID = 'unique_id'
MODELS = ["tts-1", "tts-1-hd"]
VOICES = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
URL = "https://api.openai.com/v1/audio/speech"
10 changes: 7 additions & 3 deletions custom_components/openai_tts/manifest.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
{
"domain": "openai_tts",
"name": "OpenAI TTS",
"codeowners": ["@sfortis"],
"codeowners": [
"@sfortis"
],
"config_flow": true,
"dependencies": [],
"documentation": "https://github.com/sfortis/openai_tts/",
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/sfortis/openai_tts/issues",
"requirements": ["requests>=2.25.1"],
"version": "0.2.1"
"requirements": [
"requests>=2.25.1"
],
"version": "0.2.2"
}
19 changes: 10 additions & 9 deletions custom_components/openai_tts/openaitts_engine.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import requests

from .const import URL


class OpenAITTSEngine:

def __init__(self, api_key: str, voice: str, model: str, speed: int):
def __init__(self, api_key: str, voice: str, model: str, speed: int, url: str):
self._api_key = api_key
self._voice = voice
self._model = model
self._speed = speed
self._url = URL
self._url = url

def get_tts(self, text: str):
""" Makes request to OpenAI TTS engine to convert text into audio"""
headers: dict = {"Authorization": f"Bearer {self._api_key}"}
data: dict = {"model": self._model, "input": text, "voice": self._voice, "speed": self._speed}
headers: dict = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
data: dict = {
"model": self._model,
"input": text,
"voice": self._voice,
"response_format": "wav",
"speed": self._speed
}
return requests.post(self._url, headers=headers, json=data)

@staticmethod
def get_supported_langs() -> list:
"""Returns list of supported languages. Note: the model determines the provides language automatically."""
return ["af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"]


7 changes: 4 additions & 3 deletions custom_components/openai_tts/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
"api_key": "Enter OpenAI API key.",
"speed": "Enter speed of the speech",
"model": "Select model to be used.",
"voice": "Select voice."
"voice": "Select voice.",
"url": "Enter the OpenAI-compatible endpoint. Optionally include a port number."
}
}
},
"error": {
"wrong_api_key": "Invalid API key. Please enter a valid API key.",
"already_configured": "This voice is already configured."
"already_configured": "This voice and endpoint are already configured."
},
"abort": {
"already_configured": "This voice is already configured."
"already_configured": "This voice and endpoint are already configured."
}
}
}
3 changes: 2 additions & 1 deletion custom_components/openai_tts/translations/cs.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"api_key": "Vlož OpenAI API klíč.",
"speed": "Vlož rychlost řeči.",
"model": "Vyber model k použití.",
"voice": "Vyber hlas."
"voice": "Vyber hlas.",
"url": "Zadejte koncový bod kompatibilní s OpenAI. Volitelně uveďte číslo portu."
}
}
},
Expand Down
24 changes: 24 additions & 0 deletions custom_components/openai_tts/translations/de.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"config": {
"step": {
"user": {
"title": "Füge eine Text zu Sprache Engine hinzu",
"description": "Gib Konfigurationsdaten ein. Schau in die Dokumentation für weitere Informationen.",
"data": {
"api_key": "Gib den OpenAI API Schlüssel ein.",
"speed": "Gib die Geschwindigkeit der Sprache ein",
"model": "Wähle das zu verwendende Modell.",
"voice": "Wähle eine Stimme.",
"url": "Gib den OpenAI-kompatiblen Endpunkt ein. Optional kann eine Portnummer angegeben werden."
}
}
},
"error": {
"wrong_api_key": "Ungültiger API Schlüssel. Bitte gib einen gültigen API Schlüssel ein.",
"already_configured": "Diese Stimme und Endpunkt sind bereits konfiguriert."
},
"abort": {
"already_configured": "Diese Stimme und Endpunkt sind bereits konfiguriert."
}
}
}
7 changes: 4 additions & 3 deletions custom_components/openai_tts/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
"api_key": "Enter OpenAI API key.",
"speed": "Enter speed of the speech",
"model": "Select model to be used.",
"voice": "Select voice."
"voice": "Select voice.",
"url": "Enter the OpenAI-compatible endpoint. Optionally include a port number."
}
}
},
"error": {
"wrong_api_key": "Invalid API key. Please enter a valid API key.",
"already_configured": "This voice is already configured."
"already_configured": "This voice and endpoint are already configured."
},
"abort": {
"already_configured": "This voice is already configured."
"already_configured": "This voice and endpoint are already configured."
}
}
}
19 changes: 14 additions & 5 deletions custom_components/openai_tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,29 @@
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.entity import generate_entity_id
from .const import CONF_API_KEY, CONF_MODEL, CONF_SPEED, CONF_VOICE, DOMAIN
from .const import CONF_API_KEY, CONF_MODEL, CONF_SPEED, CONF_VOICE, CONF_URL, DOMAIN, UNIQUE_ID
from .openaitts_engine import OpenAITTSEngine
from homeassistant.exceptions import MaxLengthExceeded

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up OpenAI Text-to-speech platform via config entry."""

api_key = None
if CONF_API_KEY in config_entry.data:
api_key = config_entry.data[CONF_API_KEY]

engine = OpenAITTSEngine(
config_entry.data[CONF_API_KEY],
api_key,
config_entry.data[CONF_VOICE],
config_entry.data[CONF_MODEL],
config_entry.data[CONF_SPEED],
config_entry.data[CONF_URL]
)
async_add_entities([OpenAITTSEntity(hass, config_entry, engine)])

Expand All @@ -39,7 +44,11 @@ def __init__(self, hass, config, engine):
self.hass = hass
self._engine = engine
self._config = config
self._attr_unique_id = f"{config.data[CONF_VOICE]}_{config.data[CONF_MODEL]}"

self._attr_unique_id = config.data.get(UNIQUE_ID)
if self._attr_unique_id is None:
# generate a legacy unique_id
self._attr_unique_id = f"{config.data[CONF_VOICE]}_{config.data[CONF_MODEL]}"
self.entity_id = generate_entity_id("tts.openai_tts_{}", config.data[CONF_VOICE], hass=hass)

@property
Expand Down Expand Up @@ -74,7 +83,7 @@ def get_tts_audio(self, message, language, options=None):
speech = self._engine.get_tts(message)

# The response should contain the audio file content
return "mp3", speech.content
return "wav", speech.content
except MaxLengthExceeded:
_LOGGER.error("Maximum length of the message exceeded")
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion hacs.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"name": "OpenAI TTS Speech Service",
"homeassistant": "2024.5.3",
"render_readme": true
"render_readme": true,
"version": "0.2.3"
}

0 comments on commit 537b576

Please sign in to comment.