Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store Message embedding #540

Merged
merged 13 commits into from
Jan 9, 2023
1 change: 1 addition & 0 deletions ansible/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
ports:
- 8080:8080

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""added miniLM_embedding column to message

Revision ID: 023548d474f7
Revises: ba61fe17fb6e
Create Date: 2023-01-08 11:06:25.613290

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "023548d474f7"
down_revision = "ba61fe17fb6e"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""embedding for message now in its own table

Revision ID: 35bdc1a08bb8
Revises: 023548d474f7
Create Date: 2023-01-08 16:03:48.454207

"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "35bdc1a08bb8"
down_revision = "023548d474f7"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_embedding",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.ForeignKeyConstraint(
["message_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_id", "model"),
)
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message",
sa.Column(
"miniLM_embedding",
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
autoincrement=False,
nullable=True,
),
)
op.drop_table("message_embedding")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Created date

Revision ID: aac6b2f66006
Revises: 35bdc1a08bb8
Create Date: 2023-01-08 21:28:27.342729

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "aac6b2f66006"
down_revision = "35bdc1a08bb8"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message_embedding",
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_embedding", "created_date")
# ### end Alembic commands ###
9 changes: 2 additions & 7 deletions backend/oasst_backend/api/v1/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from enum import Enum
from typing import List

from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI

router = APIRouter()


class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"


@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
Expand All @@ -30,7 +25,7 @@ async def get_text_toxicity(
ToxicityClassification: the score of toxicity of the message.
"""

api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)

Expand Down
20 changes: 18 additions & 2 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
Expand Down Expand Up @@ -253,7 +255,7 @@ def tasks_acknowledge_failure(


@router.post("/interaction", response_model=protocol_schema.TaskDone)
def tasks_interaction(
async def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
Expand All @@ -274,12 +276,26 @@ def tasks_interaction(
)

# here we store the text reply in the database
pr.store_text_reply(
newMessage = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)

if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)

return protocol_schema.TaskDone()
case protocol_schema.MessageRating:
logger.info(
Expand Down
1 change: 1 addition & 0 deletions backend/oasst_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Settings(BaseSettings):
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False

HUGGING_FACE_API_KEY: str = ""

Expand Down
2 changes: 2 additions & 0 deletions backend/oasst_backend/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_reaction import MessageReaction
from .message_tree_state import MessageTreeState
from .task import Task
Expand All @@ -13,6 +14,7 @@
"User",
"UserStats",
"Message",
"MessageEmbedding",
"MessageReaction",
"MessageTreeState",
"Task",
Expand Down
21 changes: 21 additions & 0 deletions backend/oasst_backend/models/message_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import datetime
from typing import List, Optional
from uuid import UUID

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import ARRAY, Field, Float, SQLModel


class MessageEmbedding(SQLModel, table=True):
__tablename__ = "message_embedding"
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)

message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
model: str = Field(max_length=256, nullable=False)
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)

# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
35 changes: 32 additions & 3 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from uuid import UUID, uuid4

import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
Expand Down Expand Up @@ -91,7 +91,12 @@ def insert_message(
self.db.refresh(message)
return message

def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
def store_text_reply(
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
) -> Message:
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)

Expand Down Expand Up @@ -224,6 +229,30 @@ def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReact
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)

def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.

Args:
message_id (UUID): the identifier of the message we want to save its embedding
model (str): the model used for creating the embedding
embedding (List[float]): the values obtained from the message & model

Raises:
OasstError: if misses some of the before params

Returns:
MessageEmbedding: the instance in the database of the embedding saved for that message
"""

if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)

message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(message_embedding)
self.db.commit()
self.db.refresh(message_embedding)
return message_embedding

def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
Expand Down
14 changes: 14 additions & 0 deletions backend/oasst_backend/utils/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from enum import Enum
from typing import Any, Dict

import aiohttp
from loguru import logger
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode


class HfUrl(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"


class HfEmbeddingModel(str, Enum):
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved


class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""

Expand Down Expand Up @@ -41,6 +52,9 @@ async def post(self, input: str) -> Any:
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know what happens if the text is longer than the intended text length of the model?

Copy link
Contributor Author

@nil-andreu nil-andreu Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this.
If there is a certain sentence length pre-defined in the model, the extra words will be truncated.

And the arch of our model is the following:

SentenceTransformer(
     (0): Transformer({
            'max_seq_length': 128, 
            'do_lower_case': False
            }) with Transformer model: BertModel 
     (1): Pooling({
           'word_embedding_dimension': 384, 
           'pooling_mode_cls_token': False, 
           'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 
           'pooling_mode_mean_sqrt_len_tokens': False})
)

So would have a sentence length of 128 words. The impact on us will depend then if there is an important amount of messages that have > 128 words.

Related Issue here.

# If we get a bad response
if response.status != 200:

logger.error(response)
logger.info(self.headers)
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
)
Expand Down
1 change: 1 addition & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ services:
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also add this to the ansible playbook?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When deployed via ansible the fetching of the embeddings should not be skipped?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When should be skipped? Only in local-run.sh?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's why I asked again... probably also won't be necessary for Frontend development..

depends_on:
db:
condition: service_healthy
Expand Down
1 change: 1 addition & 0 deletions scripts/backend-development/run-local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"

export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True

uvicorn main:app --reload --port 8080 --host 0.0.0.0

Expand Down