Skip to content

Commit

Permalink
Feat: refactor settings.py add LLM_ENGINGES for simple usr in DeepDiv…
Browse files Browse the repository at this point in the history
…eService
  • Loading branch information
alexiusstrauss committed Nov 27, 2023
1 parent cdd6605 commit 61db110
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 20 deletions.
4 changes: 3 additions & 1 deletion backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

audio_files/
audio_files/

.env
6 changes: 2 additions & 4 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from .services.exceptions import FileFormatException, NotFoundException
from .services.models import UploadResponse
from .services.services import DeepDive
from .settings.development import CORS_CONFIG, OPEN_AI_TOKEN
from .summarization.engines import LangChain
from .settings.development import CORS_CONFIG, LLM_ENGINE

app = FastAPI()

Expand All @@ -29,8 +28,7 @@ async def healthcheck():

@app.post("/process-audio/", response_model=UploadResponse)
async def create_upload_file(request: Request, audio_file: UploadFile = File(..., description="arquivo .mp3 ou .wav")):
lang_chain = LangChain(api_key=OPEN_AI_TOKEN)
service = DeepDive(llm_engine=lang_chain)
service = DeepDive(llm_engine=LLM_ENGINE.get('LangChain'))
service.validate_api_token()
response = service.upload_audio(audio_file)

Expand Down
3 changes: 3 additions & 0 deletions backend/src/env.exemple
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Configurações de API
OPEN_AI_TOKEN=sk-change-me
TENSORFLOW_MODEL_NAME="t5-small"
8 changes: 0 additions & 8 deletions backend/src/settings.py

This file was deleted.

15 changes: 13 additions & 2 deletions backend/src/settings/development.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Settings development mode
import os

from dotenv import load_dotenv
from src.summarization.engines import LangChain, TensorFlow

# Carregar variáveis de ambiente do arquivo .env
load_dotenv()

CORS_CONFIG = {
"allow_origins": ["*"],
Expand All @@ -7,5 +13,10 @@
"allow_headers": ["*"],
}

OPEN_AI_TOKEN = os.getenv("OPEN_AI_TOKEN")
TENSORFLOW_MODEL_NAME = os.getenv("TENSORFLOW_MODEL_NAME")

OPEN_AI_TOKEN = "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"
LLM_ENGINE = {
"LangChain": LangChain(api_key=OPEN_AI_TOKEN),
"Tensorflow": TensorFlow(model_name=TENSORFLOW_MODEL_NAME),
}
24 changes: 19 additions & 5 deletions backend/src/summarization/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain.llms.openai import OpenAI
from src.services.exceptions import ApiKeyException
from src.summarization.interfaces import Summarization
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM


class LangChain(Summarization):
Expand Down Expand Up @@ -45,9 +46,22 @@ def token_is_valid(self):
raise ApiKeyException()


class TensorFlowStrategy(Summarization):
def __init__(self):
pass
class TensorFlow(Summarization):
def __init__(self, model_name="t5-small"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

def summarize(self, text: str) -> str:
return "Texto sumarizado com TensorFlow"
def summarize(self, text: str, max_length=650):
# Preparar a entrada para o modelo
inputs = self.tokenizer.encode(
f"summarize: {text}", return_tensors="tf", max_length=1500, truncation=True
)
# Gerar a saída do modelo
summary_ids = self.model.generate(
inputs,
max_length=max_length,
length_penalty=4.0,
num_beams=4,
early_stopping=True,
)
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

0 comments on commit 61db110

Please sign in to comment.