
feat: evaluate api #463

Merged · 13 commits · Dec 12, 2024
6 changes: 5 additions & 1 deletion backend/Makefile
@@ -18,4 +18,8 @@ run_dev_server:

run_dev_celery_worker:
@echo "Running celery..."
@rye run celery -A app.celery worker
@rye run celery -A app.celery worker -Q default

run_eval_dev_celery_worker:
@echo "Running evaluation celery..."
@rye run celery -A app.celery worker -Q evaluation --loglevel=debug --pool=solo
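The Makefile change splits the dev Celery workers across two queues: the existing worker now consumes only the `default` queue, and a new `run_eval_dev_celery_worker` target consumes the `evaluation` queue. Below is a minimal sketch of the queue routing this implies on the Celery side; the module location and exact values are assumptions for illustration, not part of this diff, and only `task_default_queue` and `task_routes` are standard Celery settings.

```python
# Sketch only: assumed Celery configuration matching the two Makefile targets.
from celery import Celery

app = Celery("app")

app.conf.update(
    # consumed by `make run_dev_celery_worker` (-Q default)
    task_default_queue="default",
    task_routes={
        # evaluation tasks go to the queue consumed by `make run_eval_dev_celery_worker`
        "app.tasks.evaluate.*": {"queue": "evaluation"},
    },
)
```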
81 changes: 81 additions & 0 deletions backend/app/alembic/versions/a54f966436ce_evaluation.py
@@ -0,0 +1,81 @@
"""evaluation

Revision ID: a54f966436ce
Revises: 27a6723b767a
Create Date: 2024-12-09 16:46:21.077517

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = 'a54f966436ce'
down_revision = '27a6723b767a'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('evaluation_datasets',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('evaluation_tasks',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column('dataset_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('evaluation_datasets_items',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('query', sa.Text(), nullable=True),
sa.Column('reference', sa.Text(), nullable=True),
sa.Column('retrieved_contexts', sa.JSON(), nullable=True),
sa.Column('extra', sa.JSON(), nullable=True),
sa.Column('evaluation_dataset_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['evaluation_dataset_id'], ['evaluation_datasets.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('evaluation_items',
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chat_engine', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column('status', sa.String(length=32), nullable=False),
sa.Column('query', sa.Text(), nullable=True),
sa.Column('reference', sa.Text(), nullable=True),
sa.Column('response', sa.Text(), nullable=True),
sa.Column('retrieved_contexts', sa.JSON(), nullable=True),
sa.Column('extra', sa.JSON(), nullable=True),
sa.Column('error_msg', sa.Text(), nullable=True),
sa.Column('factual_correctness', sa.Float(), nullable=True),
sa.Column('semantic_similarity', sa.Float(), nullable=True),
sa.Column('evaluation_task_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['evaluation_task_id'], ['evaluation_tasks.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('evaluation_items')
op.drop_table('evaluation_datasets_items')
op.drop_table('evaluation_tasks')
op.drop_table('evaluation_datasets')
# ### end Alembic commands ###
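This migration creates the four tables behind the models imported by the new evaluation routes (EvaluationDataset, EvaluationDatasetItem, EvaluationTask, EvaluationItem). For orientation, here is a rough sketch of how two of those SQLModel models could be declared, inferred only from the columns above; the field types, the `evaluation_data_list` relationship name (used by `create_evaluation_dataset` below), and the omission of the timestamp columns are assumptions, not the actual model definitions.

```python
# Rough sketch inferred from the migration; not the actual app.models definitions.
from typing import List, Optional
from uuid import UUID

from sqlalchemy import JSON, Column, Text
from sqlmodel import Field, Relationship, SQLModel


class EvaluationDataset(SQLModel, table=True):
    __tablename__ = "evaluation_datasets"

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(max_length=255)
    user_id: Optional[UUID] = Field(default=None, foreign_key="users.id")

    # name matches the attribute used when constructing EvaluationDataset below
    evaluation_data_list: List["EvaluationDatasetItem"] = Relationship(
        back_populates="evaluation_dataset"
    )


class EvaluationDatasetItem(SQLModel, table=True):
    __tablename__ = "evaluation_datasets_items"

    id: Optional[int] = Field(default=None, primary_key=True)
    query: Optional[str] = Field(default=None, sa_column=Column(Text))
    reference: Optional[str] = Field(default=None, sa_column=Column(Text))
    retrieved_contexts: Optional[list] = Field(default=None, sa_column=Column(JSON))
    extra: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    evaluation_dataset_id: Optional[int] = Field(
        default=None, foreign_key="evaluation_datasets.id"
    )

    evaluation_dataset: Optional["EvaluationDataset"] = Relationship(
        back_populates="evaluation_data_list"
    )
```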
Empty file.
233 changes: 233 additions & 0 deletions backend/app/api/admin_routes/evaluation/evaluation_dataset.py
@@ -0,0 +1,233 @@
from typing import Optional, List

from fastapi import APIRouter, status, HTTPException, Depends
from fastapi_pagination import Params, Page
from sqlalchemy import func
from sqlmodel import select, case, desc

from app.api.admin_routes.evaluation.models import CreateEvaluationDataset, UpdateEvaluationDataset, \
ModifyEvaluationDatasetItem
from app.file_storage import default_file_storage
from app.models import EvaluationTask, EvaluationItem, Upload, EvaluationStatus, EvaluationDataset, \
EvaluationDatasetItem
from app.api.deps import SessionDep, CurrentSuperuserDep

import pandas as pd
from fastapi_pagination.ext.sqlmodel import paginate

from app.tasks.evaluate import add_evaluation_task
from app.types import MimeTypes

router = APIRouter()


@router.post("/admin/evaluation/dataset")
def create_evaluation_dataset(
evaluation_dataset: CreateEvaluationDataset,
session: SessionDep,
user: CurrentSuperuserDep
) -> EvaluationDataset:
"""
Create a dataset for a given question and chat engine.
This API depends on the /admin/uploads API to upload the evaluation data.
The evaluation data is expected to be a CSV file with the following columns:

- query: The query to evaluate
- reference: The expected response to the query

You can add more columns to the CSV file, and the extra columns will adhere to the results.

Args:
evaluation_dataset.name: The name of the evaluation dataset.
evaluation_dataset.upload_id: The ID of the uploaded CSV file of the evaluation dataset.

Returns:
True if the evaluation dataset is created successfully.
"""
name = evaluation_dataset.name
evaluation_file_id = evaluation_dataset.upload_id

if evaluation_file_id is not None:
# If the evaluation_file_id is provided, validate the uploaded file
upload = session.get(Upload, evaluation_file_id)

if not upload or upload.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Uploaded file not found",
)

if upload.mime_type != MimeTypes.CSV:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The uploaded file must be a CSV file.",
)

with default_file_storage.open(upload.path) as f:
df = pd.read_csv(f)

# check essential columns
must_have_columns = ["query", "reference"]
if not set(must_have_columns).issubset(df.columns):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The uploaded file must have the following columns: {must_have_columns}",
)

eval_list = df.to_dict(orient='records')
# create evaluation dataset items
evaluation_data_list = [EvaluationDatasetItem(
query=item["query"],
reference=item["reference"],
retrieved_contexts=[], # TODO: implement this after we can retrieve contexts
extra={k: item[k] for k in item if k not in must_have_columns},
) for item in eval_list]

evaluation_task = EvaluationDataset(
name=name,
user_id=user.id,
evaluation_data_list=evaluation_data_list,
)

session.add(evaluation_task)
session.commit()

return evaluation_task


@router.delete("/admin/evaluation/dataset/{evaluation_dataset_id}")
def delete_evaluation_dataset(
evaluation_dataset_id: int,
session: SessionDep,
user: CurrentSuperuserDep
) -> bool:
evaluation_dataset = session.get(EvaluationDataset, evaluation_dataset_id)

if not evaluation_dataset or evaluation_dataset.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evaluation dataset not found",
)

session.delete(evaluation_dataset)
session.commit()

return True


@router.put("/admin/evaluation/dataset/{evaluation_dataset_id}")
def update_evaluation_dataset(
evaluation_dataset_id: int,
updated_evaluation_dataset: UpdateEvaluationDataset,
session: SessionDep,
user: CurrentSuperuserDep
) -> EvaluationDataset:
evaluation_dataset = session.get(EvaluationDataset, evaluation_dataset_id)

if not evaluation_dataset or evaluation_dataset.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evaluation dataset not found",
)

evaluation_dataset.name = updated_evaluation_dataset.name

session.merge(evaluation_dataset)
session.commit()

return evaluation_dataset


@router.get("/admin/evaluation/dataset")
def list_evaluation_dataset(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[EvaluationDataset]:
stmt = (
select(EvaluationDataset)
.where(EvaluationDataset.user_id == user.id)
.order_by(desc(EvaluationDataset.id))
)
return paginate(session, stmt, params)


@router.post("/admin/evaluation/dataset-item")
def create_evaluation_dataset_item(
evaluation_dataset_item: ModifyEvaluationDatasetItem,
session: SessionDep,
user: CurrentSuperuserDep
) -> EvaluationDatasetItem:
evaluation_task_item = EvaluationDatasetItem(
query=evaluation_dataset_item.query,
reference=evaluation_dataset_item.reference,
retrieved_contexts=evaluation_dataset_item.retrieved_contexts,
extra=evaluation_dataset_item.extra,
evaluation_dataset_id=evaluation_dataset_item.evaluation_dataset_id,
)

session.add(evaluation_task_item)
session.commit()

return evaluation_task_item


@router.delete("/admin/evaluation/dataset-item/{evaluation_dataset_item_id}")
def delete_evaluation_dataset_item(
evaluation_dataset_item_id: int,
session: SessionDep,
user: CurrentSuperuserDep
) -> bool:
evaluation_dataset_item = session.get(EvaluationDatasetItem, evaluation_dataset_item_id)

if not evaluation_dataset_item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evaluation dataset item not found",
)

session.delete(evaluation_dataset_item)
session.commit()

return True


@router.put("/admin/evaluation/dataset-item/{evaluation_dataset_item_id}")
def update_evaluation_dataset_item(
evaluation_dataset_item_id: int,
updated_evaluation_dataset_item: ModifyEvaluationDatasetItem,
session: SessionDep,
user: CurrentSuperuserDep
) -> EvaluationDatasetItem:
evaluation_dataset_item = session.get(EvaluationDatasetItem, evaluation_dataset_item_id)

if not evaluation_dataset_item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evaluation dataset item not found",
)

evaluation_dataset_item.query = updated_evaluation_dataset_item.query
evaluation_dataset_item.reference = updated_evaluation_dataset_item.reference
evaluation_dataset_item.retrieved_contexts = updated_evaluation_dataset_item.retrieved_contexts
evaluation_dataset_item.extra = updated_evaluation_dataset_item.extra
evaluation_dataset_item.evaluation_dataset_id = updated_evaluation_dataset_item.evaluation_dataset_id

session.commit()

return evaluation_dataset_item


@router.get("/admin/evaluation/dataset/{evaluation_dataset_id}/dataset-item")
def list_evaluation_dataset_item(
session: SessionDep,
user: CurrentSuperuserDep,
evaluation_dataset_id: int,
params: Params = Depends(),
) -> Page[EvaluationDatasetItem]:
stmt = (
select(EvaluationDatasetItem)
.where(EvaluationDatasetItem.evaluation_dataset_id == evaluation_dataset_id)
.order_by(EvaluationDatasetItem.id)
)
return paginate(session, stmt, params)
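For reference, a hypothetical end-to-end call against the dataset API added above: upload a CSV through the uploads API, then create the dataset from that upload. The `/admin/evaluation/dataset` path and the `name`/`upload_id` body fields come from `create_evaluation_dataset`; the base URL, route prefix, auth header, and the uploads API's request and response shape are assumptions for illustration.

```python
# Hypothetical usage sketch; endpoint prefix, auth, and upload response shape are assumed.
import io

import requests

BASE = "http://localhost:8000/api/v1"                     # assumed dev server and prefix
HEADERS = {"Authorization": "Bearer <superuser-token>"}    # placeholder credential

# "topic" is an extra column: it should be carried through into the item's `extra` field.
csv_content = (
    "query,reference,topic\n"
    "What is TiDB?,TiDB is a distributed SQL database.,product\n"
)

# 1. Upload the CSV (the uploads API itself is outside this diff).
upload_resp = requests.post(
    f"{BASE}/admin/uploads",
    headers=HEADERS,
    files={"file": ("eval.csv", io.BytesIO(csv_content.encode()), "text/csv")},
)
upload_id = upload_resp.json()["id"]                       # assumed response field

# 2. Create the evaluation dataset from the upload.
dataset_resp = requests.post(
    f"{BASE}/admin/evaluation/dataset",
    headers=HEADERS,
    json={"name": "smoke-test", "upload_id": upload_id},
)
print(dataset_resp.json())                                 # the created EvaluationDataset
```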