diff --git a/.github/workflows/test-base.yaml b/.github/workflows/test-base.yaml new file mode 100644 index 00000000..8fe40ca2 --- /dev/null +++ b/.github/workflows/test-base.yaml @@ -0,0 +1,46 @@ +name: Test + +on: + workflow_call: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Cache Docker layers + uses: actions/cache@v3 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ hashFiles('Pipfile.lock', 'compose/local/django/Dockerfile') }} + restore-keys: | + ${{ runner.os }}-buildx- + - name: Bake docker images + uses: docker/bake-action@v4 + with: + load: true + set: | + *.cache-from=type=local,src=/tmp/.buildx-cache + *.cache-to=type=local,dest=/tmp/.buildx-cache-new + files: docker-compose.local.yaml + + - name: Start services + run: docker compose -f docker-compose.local.yaml up -d --wait --no-build + + - name: Check migrations + run: make checkmigration + + # - name: Run tests + # run: make test-coverage + + # - name: Upload coverage report + # uses: codecov/codecov-action@v3 + + - name: Move cache + run: | + rm -rf /tmp/.buildx-cache + mv /tmp/.buildx-cache-new /tmp/.buildx-cache diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 00000000..2ad1168a --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,12 @@ +name: Test PR + +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + uses: ./.github/workflows/test-base.yaml diff --git a/Makefile b/Makefile index aa31b0ac..209e88c2 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,9 @@ logs: makemigrations: up docker exec django bash -c "python manage.py makemigrations" + +checkmigration: + docker compose -f $(docker_config_file) exec django bash -c "python manage.py makemigrations --check --dry-run" test: up docker exec django bash -c "python manage.py test 
--keepdb --parallel=$(nproc)" diff --git a/ayushma/migrations/0051_project_tts_engine.py b/ayushma/migrations/0051_project_tts_engine.py new file mode 100644 index 00000000..5ddc3372 --- /dev/null +++ b/ayushma/migrations/0051_project_tts_engine.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.6 on 2024-02-11 15:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ayushma", "0050_alter_chat_model_alter_project_model"), + ] + + operations = [ + migrations.AddField( + model_name="project", + name="tts_engine", + field=models.SmallIntegerField( + choices=[(1, "openai"), (2, "google")], default=2 + ), + ), + ] diff --git a/ayushma/models/chat.py b/ayushma/models/chat.py index c83d005c..425633ca 100644 --- a/ayushma/models/chat.py +++ b/ayushma/models/chat.py @@ -32,6 +32,7 @@ class ChatMessage(BaseModel): original_message = models.TextField(blank=True, null=True) language = models.CharField(max_length=10, blank=False, default="en") reference_documents = models.ManyToManyField(Document, blank=True) + # generated ayushma voice audio via TTS audio = models.FileField(blank=True, null=True) meta = models.JSONField(blank=True, null=True) temperature = models.FloatField(blank=True, null=True) diff --git a/ayushma/models/enums.py b/ayushma/models/enums.py index 56611c53..a7a7c9cb 100644 --- a/ayushma/models/enums.py +++ b/ayushma/models/enums.py @@ -19,6 +19,11 @@ class STTEngine(IntegerChoices): SELF_HOSTED = 3 +class TTSEngine(IntegerChoices): + OPENAI = (1, "openai") + GOOGLE = (2, "google") + + class FeedBackRating(IntegerChoices): HALLUCINATING = 1 WRONG = 2 diff --git a/ayushma/models/project.py b/ayushma/models/project.py index e2e53531..5ebcc76d 100644 --- a/ayushma/models/project.py +++ b/ayushma/models/project.py @@ -1,7 +1,7 @@ from django.contrib.postgres.fields import ArrayField from django.db import models -from ayushma.models.enums import ModelType, STTEngine +from ayushma.models.enums import ModelType, 
STTEngine, TTSEngine from ayushma.models.users import User from utils.models.base import BaseModel @@ -16,6 +16,9 @@ class Project(BaseModel): stt_engine = models.IntegerField( choices=STTEngine.choices, default=STTEngine.WHISPER ) + tts_engine = models.SmallIntegerField( + choices=TTSEngine.choices, default=TTSEngine.GOOGLE + ) model = models.IntegerField(choices=ModelType.choices, default=ModelType.GPT_3_5) preset_questions = ArrayField(models.TextField(), null=True, blank=True) is_default = models.BooleanField(default=False) diff --git a/ayushma/serializers/chat.py b/ayushma/serializers/chat.py index ec713884..e133e3cd 100644 --- a/ayushma/serializers/chat.py +++ b/ayushma/serializers/chat.py @@ -110,6 +110,8 @@ class ConverseSerializer(serializers.Serializer): stream = serializers.BooleanField(default=True) generate_audio = serializers.BooleanField(default=True) noonce = serializers.CharField(required=False) + transcript_start_time = serializers.FloatField(required=False) + transcript_end_time = serializers.FloatField(required=False) class ChatDetailSerializer(serializers.ModelSerializer): @@ -146,9 +148,11 @@ def get_chats(self, obj): ) return [ { - "messageType": ChatMessageType.USER - if thread_message.role == "user" - else ChatMessageType.AYUSHMA, + "messageType": ( + ChatMessageType.USER + if thread_message.role == "user" + else ChatMessageType.AYUSHMA + ), "message": thread_message.content[0].text.value, "reference_documents": thread_message.content[0].text.annotations, "language": "en", @@ -159,3 +163,8 @@ def get_chats(self, obj): chatmessages = ChatMessage.objects.filter(chat=obj).order_by("created_at") context = {"request": self.context.get("request")} return ChatMessageSerializer(chatmessages, many=True, context=context).data + + +class SpeechToTextSerializer(serializers.Serializer): + audio = serializers.FileField(required=True) + language = serializers.CharField(default="en") diff --git a/ayushma/serializers/project.py 
b/ayushma/serializers/project.py index 4a7d0ddf..5681b43a 100644 --- a/ayushma/serializers/project.py +++ b/ayushma/serializers/project.py @@ -25,6 +25,7 @@ class Meta: "modified_at", "description", "stt_engine", + "tts_engine", "model", "is_default", "display_preset_questions", diff --git a/ayushma/utils/converse.py b/ayushma/utils/converse.py index 2dacb0b1..19ea4aca 100644 --- a/ayushma/utils/converse.py +++ b/ayushma/utils/converse.py @@ -33,6 +33,7 @@ def converse_api( audio = request.data.get("audio") text = request.data.get("text") language = request.data.get("language") or "en" + try: service: Service = request.service except AttributeError: @@ -128,6 +129,11 @@ def converse_api( translated_text = transcript elif converse_type == "text": + if request.data.get("transcript_start_time") and request.data.get( + "transcript_end_time" + ): + stats["transcript_start_time"] = request.data["transcript_start_time"] + stats["transcript_end_time"] = request.data["transcript_end_time"] translated_text = text if language != "en": diff --git a/ayushma/utils/language_helpers.py b/ayushma/utils/language_helpers.py index c67e6161..460e2093 100644 --- a/ayushma/utils/language_helpers.py +++ b/ayushma/utils/language_helpers.py @@ -1,9 +1,13 @@ import re +from django.conf import settings from google.cloud import texttospeech from google.cloud import translate_v2 as translate +from openai import OpenAI from rest_framework.exceptions import APIException +from ayushma.models.enums import TTSEngine + def translate_text(target, text): try: @@ -37,31 +41,43 @@ def sanitize_text(text): return sanitized_text -def text_to_speech(text, language_code): +def text_to_speech(text, language_code, service): try: # in en-IN neural voice is not available if language_code == "en-IN": language_code = "en-US" - client = texttospeech.TextToSpeechClient() - text = sanitize_text(text) - synthesis_input = texttospeech.SynthesisInput(text=text) - - voice = texttospeech.VoiceSelectionParams( - 
language_code=language_code, name=language_code_voice_map[language_code] - ) - audio_config = texttospeech.AudioConfig( - audio_encoding=texttospeech.AudioEncoding.MP3 - ) - - response = client.synthesize_speech( - input=synthesis_input, - voice=voice, - audio_config=audio_config, - ) - - return response.audio_content + + if service == TTSEngine.GOOGLE: + client = texttospeech.TextToSpeechClient() + + synthesis_input = texttospeech.SynthesisInput(text=text) + + voice = texttospeech.VoiceSelectionParams( + language_code=language_code, name=language_code_voice_map[language_code] + ) + audio_config = texttospeech.AudioConfig( + audio_encoding=texttospeech.AudioEncoding.MP3 + ) + + response = client.synthesize_speech( + input=synthesis_input, + voice=voice, + audio_config=audio_config, + ) + + return response.audio_content + elif service == TTSEngine.OPENAI: + client = OpenAI(api_key=settings.OPENAI_API_KEY) + response = client.audio.speech.create( + model="tts-1-hd", + voice="nova", + input=text, + ) + return response.read() + else: + raise APIException("Service not supported") except Exception as e: print(e) return None diff --git a/ayushma/utils/openaiapi.py b/ayushma/utils/openaiapi.py index 10f7956c..2326d5ff 100644 --- a/ayushma/utils/openaiapi.py +++ b/ayushma/utils/openaiapi.py @@ -203,6 +203,7 @@ def handle_post_response( temperature, stats, language, + tts_engine, generate_audio=True, ): chat_message: ChatMessage = ChatMessage.objects.create( @@ -225,7 +226,9 @@ def handle_post_response( ayushma_voice = None if generate_audio: stats["tts_start_time"] = time.time() - ayushma_voice = text_to_speech(translated_chat_response, user_language) + ayushma_voice = text_to_speech( + translated_chat_response, user_language, tts_engine + ) stats["tts_end_time"] = time.time() url = None @@ -324,6 +327,8 @@ def converse( elif message.messageType == ChatMessageType.AYUSHMA: chat_history.append(AIMessage(content=f"Ayushma: {message.message}")) + tts_engine = 
chat.project.tts_engine + if not stream: lang_chain_helper = LangChainHelper( stream=False, @@ -347,6 +352,7 @@ def converse( temperature, stats, language, + tts_engine, generate_audio, ) @@ -404,6 +410,7 @@ def converse( temperature, stats, language, + tts_engine, generate_audio, ) diff --git a/ayushma/utils/speech_to_text.py b/ayushma/utils/speech_to_text.py index da649942..551ba73e 100644 --- a/ayushma/utils/speech_to_text.py +++ b/ayushma/utils/speech_to_text.py @@ -1,9 +1,9 @@ import os -import openai import requests from django.conf import settings from google.cloud import speech +from openai import OpenAI from ayushma.models.enums import STTEngine @@ -14,19 +14,14 @@ def __init__(self, api_key, language_code): self.language_code = language_code def recognize(self, audio): - # workaround for setting api version ( https://github.com/openai/openai-python/pull/491 ) - current_api_version = openai.api_version - openai.api_version = "2020-11-07" - transcription = openai.Audio.transcribe( - "whisper-1", - file=audio, + client = OpenAI(api_key=self.api_key) + transcription = client.audio.transcriptions.create( + model="whisper-1", + # https://github.com/openai/openai-python/tree/main#file-uploads + file=(audio.name, audio.read()), language=self.language_code.replace("-IN", ""), - api_key=self.api_key, - api_base="https://api.openai.com/v1", - api_type="open_ai", - api_version="2020-11-07", # Bug in openai package, this parameter is ignored + # api_version="2020-11-07", ) - openai.api_version = current_api_version return transcription.text diff --git a/ayushma/views/chat.py b/ayushma/views/chat.py index 7a1cceaf..c6f72163 100644 --- a/ayushma/views/chat.py +++ b/ayushma/views/chat.py @@ -1,7 +1,9 @@ +import time + from django.conf import settings from drf_spectacular.utils import extend_schema from rest_framework import filters, status -from rest_framework.decorators import action +from rest_framework.decorators import action, api_view, permission_classes from 
rest_framework.exceptions import ValidationError
 from rest_framework.mixins import (
     CreateModelMixin,
@@ -20,8 +22,10 @@
     ChatFeedbackSerializer,
     ChatSerializer,
     ConverseSerializer,
+    SpeechToTextSerializer,
 )
 from ayushma.utils.converse import converse_api
+from ayushma.utils.speech_to_text import speech_to_text
 from utils.views.base import BaseModelViewSet
 from utils.views.mixins import PartialUpdateModelMixin
 
@@ -42,6 +46,7 @@ class ChatViewSet(
         "retrieve": ChatDetailSerializer,
         "list_all": ChatDetailSerializer,
         "converse": ConverseSerializer,
+        "speech_to_text": SpeechToTextSerializer,
     }
     permission_classes = (IsTempTokenOrAuthenticated,)
     lookup_field = "external_id"
@@ -100,6 +105,52 @@ def list_all(self, *args, **kwarg):
         serializer = self.get_serializer(queryset, many=True)
         return Response(serializer.data)
 
+    @extend_schema(
+        tags=("chats",),
+    )
+    @action(detail=True, methods=["post"])
+    def speech_to_text(self, *args, **kwarg):
+        """Transcribe an uploaded audio file with the project's STT engine.
+
+        Reads ``audio`` and ``language`` (default ``"en"``) from the request
+        and responds with ``{"transcript": ..., "stats": ...}``, where
+        ``stats`` holds the start/end timestamps of the transcription call.
+        """
+        serializer = self.get_serializer(data=self.request.data)
+        # Reject invalid payloads with a 400 up front; without
+        # raise_exception the validated_data["audio"] lookup below fails
+        # with an unhandled error (HTTP 500) when no audio was uploaded.
+        serializer.is_valid(raise_exception=True)
+
+        project_id = kwarg["project_external_id"]
+        audio = serializer.validated_data["audio"]
+        language = serializer.validated_data.get("language", "en")
+
+        stats = {}
+        try:
+            stt_engine = Project.objects.get(external_id=project_id).stt_engine
+        except Project.DoesNotExist:
+            return Response(
+                {"error": "Project not found"}, status=status.HTTP_400_BAD_REQUEST
+            )
+        try:
+            stats["transcript_start_time"] = time.time()
+            transcript = speech_to_text(stt_engine, audio, language + "-IN")
+            stats["transcript_end_time"] = time.time()
+            translated_text = transcript
+        except Exception as e:
+            print(f"Failed to transcribe speech with {stt_engine} engine: {e}")
+            return Response(
+                {
+                    "error": "Something went wrong in getting transcription, please try again later"
+                },
+                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+
+        return Response(
+            {"transcript": translated_text, "stats": stats}, status=status.HTTP_200_OK
+        )
+
     @extend_schema(
         tags=("chats",),
     )
diff
--git a/compose/local/django/Dockerfile b/compose/local/django/Dockerfile index a4d28829..432fd92c 100644 --- a/compose/local/django/Dockerfile +++ b/compose/local/django/Dockerfile @@ -22,7 +22,7 @@ RUN pip install pipenv # Requirements are installed here to ensure they will be cached. COPY Pipfile Pipfile.lock ./ -RUN pipenv sync --system --categories "packages" +RUN pipenv sync --system --categories "packages dev-packages" # Python 'run' stage FROM python as python-run-stage diff --git a/utils/pagination.py b/utils/pagination.py index 05509fe4..c54def1c 100644 --- a/utils/pagination.py +++ b/utils/pagination.py @@ -13,5 +13,6 @@ def get_paginated_response(self, data): "has_previous": self.offset > 0, "has_next": self.offset + self.limit < self.count, "results": data, + "offset": self.offset, } )