
feat: evaluate api #463

Merged · 13 commits · Dec 12, 2024
6 changes: 5 additions & 1 deletion backend/Makefile
@@ -18,4 +18,8 @@ run_dev_server:

 run_dev_celery_worker:
 	@echo "Running celery..."
-	@rye run celery -A app.celery worker
+	@rye run celery -A app.celery worker -Q default
+
+run_eval_dev_celery_worker:
+	@echo "Running evaluation celery..."
+	@rye run celery -A app.celery worker -Q evaluation --loglevel=debug --pool=solo
59 changes: 59 additions & 0 deletions backend/app/alembic/versions/83e81f4c63d6_evaluation.py
@@ -0,0 +1,59 @@
"""evaluation

Revision ID: 83e81f4c63d6
Revises: 27a6723b767a
Create Date: 2024-12-04 14:06:42.926516

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = '83e81f4c63d6'
down_revision = '27a6723b767a'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('evaluation_tasks',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column('upload_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['upload_id'], ['uploads.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('evaluation_items',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chat_engine', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('status', sa.String(length=32), nullable=False),
sa.Column('query', sa.Text(), nullable=True),
sa.Column('reference', sa.Text(), nullable=True),
sa.Column('response', sa.Text(), nullable=True),
sa.Column('retrieved_contexts', sa.JSON(), nullable=True),
sa.Column('extra', sa.JSON(), nullable=True),
sa.Column('error_msg', sa.Text(), nullable=True),
sa.Column('factual_correctness', sa.Float(), nullable=True),
sa.Column('semantic_similarity', sa.Float(), nullable=True),
sa.Column('evaluation_task_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['evaluation_task_id'], ['evaluation_tasks.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('evaluation_items')
op.drop_table('evaluation_tasks')
# ### end Alembic commands ###
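For reference, a minimal sketch of SQLModel definitions that would match this migration. The real models live in backend/app/models/evaluation_task.py, which is not shown in this diff, so the enum values, relationship options, and timestamp handling below are assumptions inferred from the columns:

import enum
from typing import Optional
from uuid import UUID

from sqlalchemy import Column, JSON, Text
from sqlmodel import Field, Relationship, SQLModel


class EvaluationStatus(str, enum.Enum):
    # Assumed values, matching the status counts used by the API below.
    NOT_START = "not_start"
    EVALUATING = "evaluating"
    DONE = "done"
    ERROR = "error"


class EvaluationTask(SQLModel, table=True):
    __tablename__ = "evaluation_tasks"
    # created_at/updated_at (server-default timestamps) elided for brevity.

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(max_length=255)
    user_id: Optional[UUID] = Field(default=None, foreign_key="users.id")
    upload_id: Optional[int] = Field(default=None, foreign_key="uploads.id")
    evaluation_items: list["EvaluationItem"] = Relationship(back_populates="evaluation_task")


class EvaluationItem(SQLModel, table=True):
    __tablename__ = "evaluation_items"

    id: Optional[int] = Field(default=None, primary_key=True)
    chat_engine: str = Field(max_length=255)
    status: str = Field(max_length=32)  # holds an EvaluationStatus value
    query: Optional[str] = Field(default=None, sa_column=Column(Text))
    reference: Optional[str] = Field(default=None, sa_column=Column(Text))
    response: Optional[str] = Field(default=None, sa_column=Column(Text))
    retrieved_contexts: Optional[list] = Field(default=None, sa_column=Column(JSON))
    extra: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    error_msg: Optional[str] = Field(default=None, sa_column=Column(Text))
    factual_correctness: Optional[float] = None
    semantic_similarity: Optional[float] = None
    evaluation_task_id: Optional[int] = Field(default=None, foreign_key="evaluation_tasks.id")
    evaluation_task: Optional["EvaluationTask"] = Relationship(back_populates="evaluation_items")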
200 changes: 200 additions & 0 deletions backend/app/api/admin_routes/evaluation.py
@@ -0,0 +1,200 @@
from typing import Optional, List

from fastapi import APIRouter, status, HTTPException, Depends
from fastapi_pagination import Params, Page
from sqlalchemy import func
from sqlmodel import select, case, desc

from app.api.admin_routes.models import CreateEvaluationTask, EvaluationTaskSummary
from app.file_storage import default_file_storage
from app.models import EvaluationTask, EvaluationItem, Upload, EvaluationStatus
from app.api.deps import SessionDep, CurrentSuperuserDep

import pandas as pd
from fastapi_pagination.ext.sqlmodel import paginate

from app.tasks.evaluate import add_evaluation_task
from app.types import MimeTypes

router = APIRouter()


@router.post("/admin/evaluation/task")
def create_evaluation_task(
evaluation_task: CreateEvaluationTask,
session: SessionDep,
user: CurrentSuperuserDep
) -> Optional[EvaluationTask]:
"""
Create an evaluation task for the uploaded queries against a given chat engine.
This API depends on the /admin/uploads API to upload the evaluation data.
The evaluation data is expected to be a CSV file with the following columns:

- query: The query to evaluate
- reference: The expected response to the query

You can add more columns to the CSV file; the extra columns will be carried through to the results in each item's `extra` field.

Args:
evaluation_task.name: The name of the evaluation task.
evaluation_task.upload_id: The ID of the uploaded evaluation CSV file.
evaluation_task.chat_engine: The chat engine to evaluate the queries against. Default is "default".
evaluation_task.run_size: The number of queries to evaluate. Default is None, which means all queries in the CSV file.

Returns:
The created EvaluationTask.
"""

name = evaluation_task.name
evaluation_file_id = evaluation_task.upload_id
chat_engine = evaluation_task.chat_engine
run_size = evaluation_task.run_size

upload = session.get(Upload, evaluation_file_id)

# validate the uploaded evaluation file
if not upload or upload.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Uploaded file not found",
)

if upload.mime_type != MimeTypes.CSV:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The uploaded file must be a CSV file.",
)

# retrieve the csv file and check the columns
with default_file_storage.open(upload.path) as f:
df = pd.read_csv(f)

# check essential columns
must_have_columns = ["query", "reference"]
if not set(must_have_columns).issubset(df.columns):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The uploaded file must have the following columns: {must_have_columns}",
)

eval_list = df.to_dict(orient='records')
if run_size is not None and run_size < len(eval_list):
eval_list = eval_list[:run_size]

evaluation_task = EvaluationTask(
name=name,
user_id=user.id,
upload_id=evaluation_file_id,
)

# create evaluation items
evaluation_items = [EvaluationItem(
status=EvaluationStatus.NOT_START,
chat_engine=chat_engine,
query=item["query"],
reference=item["reference"],
extra={k: item[k] for k in item if k not in must_have_columns},
) for item in eval_list]

evaluation_task.evaluation_items = evaluation_items

session.add(evaluation_task)
session.commit()

add_evaluation_task.delay(evaluation_task.id)

return evaluation_task


@router.get("/admin/evaluation/task-summary/{evaluation_task_id}")
def get_evaluation_task_summary(
evaluation_task_id: int,
session: SessionDep,
user: CurrentSuperuserDep
) -> EvaluationTaskSummary:
task = session.exec(select(EvaluationTask).where(EvaluationTask.id == evaluation_task_id)).first()
if not task:
raise HTTPException(status_code=404, detail="EvaluationTask not found")

if task.user_id != user.id:
raise HTTPException(status_code=403, detail="Access denied")

status_counts = (
session.query(
func.count(case((EvaluationItem.status == EvaluationStatus.NOT_START, 1), else_=None)).label("not_start"),
func.count(case((EvaluationItem.status == EvaluationStatus.EVALUATING, 1), else_=None)).label(
"evaluating"),
func.count(case((EvaluationItem.status == EvaluationStatus.DONE, 1), else_=None)).label("done"),
func.count(case((EvaluationItem.status == EvaluationStatus.ERROR, 1), else_=None)).label("error"),
)
.filter(EvaluationItem.evaluation_task_id == evaluation_task_id)
.one()
)

# statistics are only populated once no items remain unprocessed
stats = None
if status_counts.not_start == 0 and status_counts.evaluating == 0:
stats = (
session.query(
func.avg(EvaluationItem.factual_correctness).label('avg_factual_correctness'),
func.avg(EvaluationItem.semantic_similarity).label('avg_semantic_similarity'),
func.min(EvaluationItem.factual_correctness).label('min_factual_correctness'),
func.min(EvaluationItem.semantic_similarity).label('min_semantic_similarity'),
func.max(EvaluationItem.factual_correctness).label('max_factual_correctness'),
func.max(EvaluationItem.semantic_similarity).label('max_semantic_similarity'),
func.stddev(EvaluationItem.factual_correctness).label('std_factual_correctness'),
func.stddev(EvaluationItem.semantic_similarity).label('std_semantic_similarity'),
)
.filter(
EvaluationItem.evaluation_task_id == evaluation_task_id,
EvaluationItem.status == EvaluationStatus.DONE,
EvaluationItem.factual_correctness.isnot(None),
EvaluationItem.semantic_similarity.isnot(None),
)
.one()
)

return EvaluationTaskSummary(
task=task,
not_start=status_counts.not_start,
succeed=status_counts.done,
errored=status_counts.error,
progressing=status_counts.evaluating,
avg_factual_correctness=stats.avg_factual_correctness if stats else None,
avg_semantic_similarity=stats.avg_semantic_similarity if stats else None,
min_factual_correctness=stats.min_factual_correctness if stats else None,
min_semantic_similarity=stats.min_semantic_similarity if stats else None,
max_factual_correctness=stats.max_factual_correctness if stats else None,
max_semantic_similarity=stats.max_semantic_similarity if stats else None,
std_factual_correctness=stats.std_factual_correctness if stats else None,
std_semantic_similarity=stats.std_semantic_similarity if stats else None,
)


@router.get("/admin/evaluation/task")
def list_evaluation_task(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[EvaluationTask]:
stmt = (
select(EvaluationTask)
.where(EvaluationTask.user_id == user.id)
.order_by(desc(EvaluationTask.id))
)
return paginate(session, stmt, params)


@router.get("/admin/evaluation/all-items/{evaluation_task_id}")
def list_evaluation_items(
evaluation_task_id: int,
session: SessionDep,
user: CurrentSuperuserDep,
) -> List[EvaluationItem]:
task = session.exec(select(EvaluationTask).where(EvaluationTask.id == evaluation_task_id)).first()
if not task:
raise HTTPException(status_code=404, detail="EvaluationTask not found")

if task.user_id != user.id:
raise HTTPException(status_code=403, detail="Access denied")

return task.evaluation_items
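End to end, the flow is: upload a CSV, create the task, then poll the summary. A hedged client-side sketch follows; the upload endpoint path, auth header, and response field names are assumptions based on the docstring above, not confirmed by this diff:

import requests

BASE = "http://localhost:8000/api/v1"          # assumed API prefix
HEADERS = {"Authorization": "Bearer <token>"}  # assumed superuser auth

# 1. Upload the evaluation CSV. It must contain `query` and `reference`
#    columns; any extra columns end up in each item's `extra` field.
csv_bytes = (
    b"query,reference,topic\n"
    b"What is TiDB?,TiDB is a distributed SQL database.,intro\n"
)
upload = requests.post(
    f"{BASE}/admin/uploads",                   # assumed upload endpoint path
    headers=HEADERS,
    files={"file": ("eval.csv", csv_bytes, "text/csv")},
).json()

# 2. Create the evaluation task; add_evaluation_task.delay() then hands it
#    to the Celery worker listening on the `evaluation` queue.
task = requests.post(
    f"{BASE}/admin/evaluation/task",
    headers=HEADERS,
    json={
        "name": "smoke-test",
        "upload_id": upload["id"],             # assumed response shape
        "chat_engine": "default",
        "run_size": 10,
    },
).json()

# 3. Poll the summary; statistics appear once not_start and evaluating
#    both reach zero.
summary = requests.get(
    f"{BASE}/admin/evaluation/task-summary/{task['id']}",
    headers=HEADERS,
).json()
print(summary["succeed"], summary["errored"])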
27 changes: 26 additions & 1 deletion backend/app/api/admin_routes/models.py
@@ -1,7 +1,9 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel

from app.api.admin_routes.embedding_model.models import EmbeddingModelItem
from app.models import EvaluationTask
from app.types import LLMProvider


@@ -34,4 +36,27 @@ class DataSourceDescriptor(BaseModel):
class ChatEngineDescriptor(BaseModel):
id: int
name: str
-is_default: bool
+is_default: bool


class CreateEvaluationTask(BaseModel):
name: str
upload_id: int
chat_engine: str = "default"
run_size: Optional[int] = None


class EvaluationTaskSummary(BaseModel):
task: EvaluationTask
not_start: int
succeed: int
errored: int
progressing: int
avg_factual_correctness: Optional[float]
avg_semantic_similarity: Optional[float]
min_factual_correctness: Optional[float]
min_semantic_similarity: Optional[float]
max_factual_correctness: Optional[float]
max_semantic_similarity: Optional[float]
std_factual_correctness: Optional[float]
std_semantic_similarity: Optional[float]
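As a shape illustration only (all numbers below are invented), a task-summary response mid-run keeps the statistics fields null, because the endpoint only aggregates once nothing remains in not_start or evaluating:

# Hypothetical EvaluationTaskSummary payload while items are still running.
in_progress_summary = {
    "task": {"id": 1, "name": "smoke-test"},  # invented task fields
    "not_start": 3,
    "succeed": 5,
    "errored": 0,
    "progressing": 2,
    "avg_factual_correctness": None,   # statistics stay null until every
    "avg_semantic_similarity": None,   # item leaves not_start/evaluating
    "min_factual_correctness": None,
    "min_semantic_similarity": None,
    "max_factual_correctness": None,
    "max_semantic_similarity": None,
    "std_factual_correctness": None,
    "std_semantic_similarity": None,
}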
1 change: 1 addition & 0 deletions backend/app/api/admin_routes/upload.py
@@ -19,6 +19,7 @@
".docx": MimeTypes.DOCX,
".pptx": MimeTypes.PPTX,
".xlsx": MimeTypes.XLSX,
".csv": MimeTypes.CSV,
}


2 changes: 2 additions & 0 deletions backend/app/api/main.py
@@ -25,6 +25,7 @@
stats as admin_stats,
semantic_cache as admin_semantic_cache,
langfuse as admin_langfuse,
evaluation as admin_evaluation,
)
from app.auth.users import auth_backend, fastapi_users
from app.api.deps import current_superuser
@@ -51,6 +52,7 @@
api_router.include_router(admin_retrieve.router, tags=["admin/retrieve"])
api_router.include_router(admin_stats.router, tags=["admin/stats"])
api_router.include_router(admin_semantic_cache.router, tags=["admin/semantic_cache"])
api_router.include_router(admin_evaluation.router, tags=["admin/evaluation"])

api_router.include_router(
fastapi_users.get_auth_router(auth_backend), prefix="/auth", tags=["auth"]
4 changes: 4 additions & 0 deletions backend/app/celery.py
@@ -12,6 +12,10 @@
app.conf.update(
task_acks_late=True,
task_reject_on_worker_lost=True,
task_routes=[
{"app.tasks.evaluate.*": {"queue": "evaluation"}},
{"*": {"queue": "default"}}
]
)

app.autodiscover_tasks(["app"])
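The routing list is order-sensitive: Celery tries each glob in turn, so anything under app.tasks.evaluate.* lands on the evaluation queue and everything else falls through to default. A small self-contained sketch of the same mechanism (the second task name is invented for illustration):

from celery import Celery

app = Celery("demo")
app.conf.task_routes = [
    {"app.tasks.evaluate.*": {"queue": "evaluation"}},  # matched first
    {"*": {"queue": "default"}},                        # catch-all
]

# app.tasks.evaluate.add_evaluation_task        -> queue "evaluation"
# app.tasks.rag_build.build_index (hypothetical) -> queue "default"
#
# The Makefile targets above start one worker per queue (-Q evaluation /
# -Q default), so evaluation jobs never block regular background work.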
2 changes: 2 additions & 0 deletions backend/app/core/config.py
@@ -104,6 +104,8 @@ def _validate_sentry_sample_rate(self) -> Self:
EMBEDDING_DIMS: int = 1536
EMBEDDING_MAX_TOKENS: int = 8191

EVALUATION_OPENAI_API_KEY: str | None = None

@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_DATABASE_URI(self) -> MySQLDsn:
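The new EVALUATION_OPENAI_API_KEY setting implies the evaluation worker talks to an LLM judge with credentials separate from the application's own model configuration. A hedged sketch of how app/tasks/evaluate.py (not shown in this diff) might consume it; the OpenAI client wiring is an assumption:

from openai import OpenAI  # assumed dependency of the evaluation task

from app.core.config import settings


def build_judge_client() -> OpenAI:
    # Fail fast if the evaluation-specific key was never configured.
    if not settings.EVALUATION_OPENAI_API_KEY:
        raise RuntimeError("EVALUATION_OPENAI_API_KEY is not configured")
    return OpenAI(api_key=settings.EVALUATION_OPENAI_API_KEY)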
1 change: 1 addition & 0 deletions backend/app/models/__init__.py
@@ -27,3 +27,4 @@
from .embed_model import EmbeddingModel
from .reranker_model import RerankerModel, AdminRerankerModel
from .recommend_question import RecommendQuestion
from .evaluation_task import EvaluationTask, EvaluationItem, EvaluationStatus