Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: evaluate api #463

Merged
merged 13 commits into from
Dec 12, 2024
6 changes: 5 additions & 1 deletion backend/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ run_dev_server:

# Start a Celery worker for the `app.celery` application through `rye run`.
# NOTE(review): this text comes from a diff view, so two worker commands are
# listed; presumably the first is the pre-change line and the second, which
# restricts the worker to the `default` queue via `-Q default`, replaces it.
run_dev_celery_worker:
@echo "Running celery..."
@rye run celery -A app.celery worker
@rye run celery -A app.celery worker -Q default

# Start a Celery worker for the `app.celery` application that consumes only the
# `evaluation` queue, with debug-level logging and the `solo` pool.
run_eval_dev_celery_worker:
@echo "Running evaluation celery..."
@rye run celery -A app.celery worker -Q evaluation --loglevel=debug --pool=solo
81 changes: 81 additions & 0 deletions backend/app/alembic/versions/a54f966436ce_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""evaluation

Revision ID: a54f966436ce
Revises: 27a6723b767a
Create Date: 2024-12-09 16:46:21.077517

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = 'a54f966436ce'
down_revision = '27a6723b767a'
branch_labels = None
depends_on = None


def upgrade():
    """Create the evaluation tables.

    Parent tables (`evaluation_datasets`, `evaluation_tasks`) are created
    before the item tables that hold foreign keys to them.
    """

    def timestamp_columns():
        # Build new Column objects on every call: each table needs its own
        # `created_at` / `updated_at` column instances.
        return [
            sa.Column(
                column_name,
                sa.DateTime(timezone=True),
                server_default=sa.text('now()'),
                nullable=True,
            )
            for column_name in ('created_at', 'updated_at')
        ]

    # Datasets: a named collection of evaluation items linked to a user.
    op.create_table(
        'evaluation_datasets',
        *timestamp_columns(),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
        sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
    )

    # Tasks: `dataset_id` is a plain integer column here (no FK constraint).
    op.create_table(
        'evaluation_tasks',
        *timestamp_columns(),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
        sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column('dataset_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
    )

    # Dataset items: one query/reference pair per row, tied to a dataset.
    op.create_table(
        'evaluation_datasets_items',
        *timestamp_columns(),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('query', sa.Text(), nullable=True),
        sa.Column('reference', sa.Text(), nullable=True),
        sa.Column('retrieved_contexts', sa.JSON(), nullable=True),
        sa.Column('extra', sa.JSON(), nullable=True),
        sa.Column('evaluation_dataset_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['evaluation_dataset_id'], ['evaluation_datasets.id'], ),
        sa.PrimaryKeyConstraint('id'),
    )

    # Task items: per-query result rows (response, status, scores) of a task.
    op.create_table(
        'evaluation_task_items',
        *timestamp_columns(),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('chat_engine', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
        sa.Column('status', sa.String(length=32), nullable=False),
        sa.Column('query', sa.Text(), nullable=True),
        sa.Column('reference', sa.Text(), nullable=True),
        sa.Column('response', sa.Text(), nullable=True),
        sa.Column('retrieved_contexts', sa.JSON(), nullable=True),
        sa.Column('extra', sa.JSON(), nullable=True),
        sa.Column('error_msg', sa.Text(), nullable=True),
        sa.Column('factual_correctness', sa.Float(), nullable=True),
        sa.Column('semantic_similarity', sa.Float(), nullable=True),
        sa.Column('evaluation_task_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['evaluation_task_id'], ['evaluation_tasks.id'], ),
        sa.PrimaryKeyConstraint('id'),
    )


def downgrade():
    """Drop the evaluation tables created by `upgrade`.

    The item tables go first because they hold foreign keys to
    `evaluation_tasks` and `evaluation_datasets`.
    """
    tables_in_drop_order = (
        'evaluation_task_items',
        'evaluation_datasets_items',
        'evaluation_tasks',
        'evaluation_datasets',
    )
    for table_name in tables_in_drop_order:
        op.drop_table(table_name)
Empty file.
197 changes: 197 additions & 0 deletions backend/app/api/admin_routes/evaluation/evaluation_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import pandas as pd
from fastapi import APIRouter, status, HTTPException, Depends
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, desc

from app.api.admin_routes.evaluation.models import CreateEvaluationDataset, UpdateEvaluationDataset, \
ModifyEvaluationDatasetItem
from app.api.admin_routes.evaluation.tools import must_get, must_get_and_belong
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.file_storage import default_file_storage
from app.models import Upload, EvaluationDataset, EvaluationDatasetItem
from app.types import MimeTypes

router = APIRouter()


@router.post("/admin/evaluation/dataset")
def create_evaluation_dataset(
    evaluation_dataset: CreateEvaluationDataset,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> EvaluationDataset:
    """
    Create an evaluation dataset, optionally seeded from an uploaded CSV file.

    This API depends on the /admin/uploads API to upload the evaluation data.
    The evaluation data is expected to be a CSV file with the following columns:

    - query: The query to evaluate
    - reference: The expected response to the query

    Any additional CSV columns are kept: for each row they are stored in the
    item's `extra` field, keyed by column name. Empty CSV cells are stored as
    null rather than NaN.

    Args:
        evaluation_dataset.name: The name of the evaluation dataset.
        evaluation_dataset.upload_id: The ID of the uploaded CSV file of the
            evaluation dataset. When it is None, the dataset is created
            without any items.

    Returns:
        The newly created evaluation dataset.

    Raises:
        HTTPException: 400 if the upload is not a CSV file or lacks the
            `query` / `reference` columns. The upload lookup goes through
            `must_get_and_belong`, which presumably raises when the upload is
            missing or not owned by the user (confirm in its definition).
    """
    name = evaluation_dataset.name
    evaluation_file_id = evaluation_dataset.upload_id

    # Default to no items so that a request without an upload still works
    # instead of referencing an unassigned variable below.
    evaluation_data_list = []

    if evaluation_file_id is not None:
        # If the evaluation_file_id is provided, validate the uploaded file
        upload = must_get_and_belong(session, Upload, evaluation_file_id, user.id)

        if upload.mime_type != MimeTypes.CSV:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The uploaded file must be a CSV file.",
            )

        with default_file_storage.open(upload.path) as f:
            df = pd.read_csv(f)

        # check essential columns
        must_have_columns = ["query", "reference"]
        if not set(must_have_columns).issubset(df.columns):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"The uploaded file must have the following columns: {must_have_columns}",
            )

        # pandas represents empty cells as float NaN, which is not valid JSON
        # for the `extra` column and is meaningless as query/reference text;
        # store None (SQL NULL / JSON null) instead.
        df = df.astype(object).where(df.notna(), None)

        # create evaluation dataset items
        evaluation_data_list = [
            EvaluationDatasetItem(
                query=item["query"],
                reference=item["reference"],
                retrieved_contexts=[],  # TODO: implement this after we can retrieve contexts
                extra={k: item[k] for k in item if k not in must_have_columns},
            )
            for item in df.to_dict(orient='records')
        ]

    # Use a separate local name rather than rebinding the request parameter.
    dataset = EvaluationDataset(
        name=name,
        user_id=user.id,
        evaluation_data_list=evaluation_data_list,
    )

    session.add(dataset)
    session.commit()

    return dataset


@router.delete("/admin/evaluation/dataset/{evaluation_dataset_id}")
def delete_evaluation_dataset(
    evaluation_dataset_id: int,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> bool:
    """
    Delete the evaluation dataset with the given ID.

    The dataset is fetched through `must_get_and_belong` with the current
    user's ID before it is removed and the change is committed.

    Returns:
        Always True once the deletion has been committed.
    """
    dataset_to_remove = must_get_and_belong(
        session,
        EvaluationDataset,
        evaluation_dataset_id,
        user.id,
    )
    session.delete(dataset_to_remove)
    session.commit()
    return True


@router.put("/admin/evaluation/dataset/{evaluation_dataset_id}")
def update_evaluation_dataset(
    evaluation_dataset_id: int,
    updated_evaluation_dataset: UpdateEvaluationDataset,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> EvaluationDataset:
    """
    Rename an existing evaluation dataset.

    Only the `name` field is taken from the request body. The dataset is
    fetched through `must_get_and_belong` with the current user's ID.

    Returns:
        The dataset after the new name has been committed.
    """
    dataset = must_get_and_belong(
        session,
        EvaluationDataset,
        evaluation_dataset_id,
        user.id,
    )
    new_name = updated_evaluation_dataset.name
    dataset.name = new_name

    session.merge(dataset)
    session.commit()
    return dataset


@router.get("/admin/evaluation/datasets")
def list_evaluation_dataset(
    session: SessionDep,
    user: CurrentSuperuserDep,
    params: Params = Depends(),
) -> Page[EvaluationDataset]:
    """
    Return a page of the current user's evaluation datasets.

    Results are filtered by `user_id` and ordered by dataset ID, newest
    (highest ID) first.
    """
    owned_by_user = EvaluationDataset.user_id == user.id
    newest_first = desc(EvaluationDataset.id)
    query = select(EvaluationDataset).where(owned_by_user).order_by(newest_first)
    return paginate(session, query, params)


@router.post("/admin/evaluation/dataset-item")
def create_evaluation_dataset_item(
    modify_evaluation_dataset_item: ModifyEvaluationDatasetItem,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> EvaluationDatasetItem:
    """
    Add a single item to an evaluation dataset.

    The query, reference, retrieved contexts, extra data and target dataset ID
    are copied from the request body into a new `EvaluationDatasetItem`.

    Returns:
        The item after it has been added and committed.
    """
    payload = modify_evaluation_dataset_item
    copied_fields = (
        "query",
        "reference",
        "retrieved_contexts",
        "extra",
        "evaluation_dataset_id",
    )
    new_item = EvaluationDatasetItem(
        **{field: getattr(payload, field) for field in copied_fields}
    )

    session.add(new_item)
    session.commit()
    return new_item


@router.delete("/admin/evaluation/dataset-item/{evaluation_dataset_item_id}")
def delete_evaluation_dataset_item(
    evaluation_dataset_item_id: int,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> bool:
    """
    Delete a single evaluation dataset item.

    Args:
        evaluation_dataset_item_id: The ID of the dataset item to delete.

    Returns:
        Always True once the deletion has been committed.
    """
    # Look the ID up as an EvaluationDatasetItem. Using EvaluationDataset here
    # would delete a whole dataset that happens to share the item's ID.
    evaluation_dataset_item = must_get(session, EvaluationDatasetItem, evaluation_dataset_item_id)

    session.delete(evaluation_dataset_item)
    session.commit()

    return True


@router.put("/admin/evaluation/dataset-item/{evaluation_dataset_item_id}")
def update_evaluation_dataset_item(
    evaluation_dataset_item_id: int,
    updated_evaluation_dataset_item: ModifyEvaluationDatasetItem,
    session: SessionDep,
    user: CurrentSuperuserDep
) -> EvaluationDatasetItem:
    """
    Overwrite the editable fields of an evaluation dataset item.

    The query, reference, retrieved contexts, extra data and dataset ID of the
    stored item are all replaced with the values from the request body.

    Returns:
        The item after the changes have been committed.
    """
    item = must_get(session, EvaluationDatasetItem, evaluation_dataset_item_id)

    # Copy the fields in the same order as they are declared on the payload use.
    for field in (
        "query",
        "reference",
        "retrieved_contexts",
        "extra",
        "evaluation_dataset_id",
    ):
        setattr(item, field, getattr(updated_evaluation_dataset_item, field))

    session.commit()
    return item


@router.get("/admin/evaluation/datasets/{evaluation_dataset_id}/dataset-items")
def list_evaluation_dataset_item(
    session: SessionDep,
    user: CurrentSuperuserDep,
    evaluation_dataset_id: int,
    params: Params = Depends(),
) -> Page[EvaluationDatasetItem]:
    """
    Return a page of the items that belong to the given evaluation dataset.

    Items are filtered by `evaluation_dataset_id` and ordered by ascending
    item ID.
    """
    in_dataset = EvaluationDatasetItem.evaluation_dataset_id == evaluation_dataset_id
    query = select(EvaluationDatasetItem).where(in_dataset)
    query = query.order_by(EvaluationDatasetItem.id)
    return paginate(session, query, params)
Loading
Loading