Skip to content

Commit

Permalink
adding tests for Fast API routes
Browse files Browse the repository at this point in the history
  • Loading branch information
paraskuk committed Mar 22, 2024
1 parent 71b2bb9 commit c332fec
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 57 deletions.
34 changes: 24 additions & 10 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ async def login_error(request: Request, message: str):

@app.get("/logout")
async def logout(request: Request):

"""
Function to log out from GitHub
:param request:
Expand Down Expand Up @@ -166,9 +165,9 @@ async def authenticated(request: Request):


@app.post("/save-to-github/{username}/{repository}")
async def save_to_github(repository: str, username:str, request: Request, file: GitHubFile):
async def save_to_github(repository: str, username: str, request: Request, file: GitHubFile):
log.info("Starting save to github route")
#potentially remove the repository parameter and add file.repository
# potentially remove the repository parameter and add file.repository
log.info(f"repo value is {repository} username value is {username}")
if 'auth_token' not in request.state.session:
log.info("auth token not in session, raising exception")
Expand Down Expand Up @@ -250,6 +249,19 @@ async def index(request: Request):
return http_exception_handler(exc)


def create_client_moderation(query_params):
"""
Function to create a client moderation request
:param query_params:
:param user_input: str, user input to be sent to the moderation API
:return: Moderation response
"""
moderation_result = client.moderations.create(
input=query_params.user_input
)
return moderation_result


def create_gpt4_completion(model: str, system_message: str, user_input: str) -> None or Optional[str]:
"""
Function to create a GPT-4 completion request
Expand Down Expand Up @@ -293,9 +305,11 @@ async def ask_gpt4(query_params: QueryModel) -> JSONResponse:
)

# Moderation API to evaluate the query
moderation_result = client.moderations.create(
input=query_params.user_input
)
moderation_result = create_client_moderation(query_params)

# moderation_result = client.moderations.create(
# input=query_params.user_input
# )

if not code_completion:
raise HTTPException(status_code=500, detail="No response from the model for code completion.")
Expand Down Expand Up @@ -364,10 +378,10 @@ async def send_feedback(feedback_data: FeedbackModel):
:param feedback_data: Instance of FeedbackModel
:return: Dictionary with message indicating feedback received
"""
# Log feedback
log.info(f"Received feedback: {feedback_data.feedback} for response ID: {feedback_data.responseId}")

# Here you can add logic to analyze feedback or store it for future improvements
# Note: GPT-4 doesn't have a direct mechanism to improve based on this feedback
log.info(f"Received feedback: {feedback_data.feedback} for response ID: {feedback_data.responseId}")

return {"message": "Feedback received"}



2 changes: 1 addition & 1 deletion models/query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class QueryModel(BaseModel):
Data class Model for query input sets max length of user_input to 2000
"""
user_input: str = Field(min_length=1, max_length=2000)
model: str = "gpt-4-0125-preview" # Default model is gpt-4
model: str = "gpt-4" # Default model is gpt-4


class FeedbackModel(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pyflakes==3.0.1
Pygments==2.15.1
pyparsing==3.1.0
pytest==7.4.0
pytest-asyncio==0.23.5.post1
pytest-cov==4.1.0
PyYAML==6.0
redis==5.0.1
Expand Down
10 changes: 9 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,12 @@ def test_client() -> TestClient:
:return: a Test client object.
"""
with TestClient(app) as client:
yield client
yield client

@pytest.fixture
def test_app():
"""
Fixture to return the FastAPI app instance
:return: app
"""
return app
154 changes: 109 additions & 45 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,129 @@
import json
import uuid

import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
from typing import Union, Any
from app import ask_gpt4, app, http_exception_handler
from models.query_model import QueryModel
from exceptions.GPTException import GPTException
from models.query_model import QueryModel, RedisSessionMiddleware
from unittest.mock import patch, Mock
from fastapi.responses import RedirectResponse
from starlette.requests import Request
from unittest.mock import patch, AsyncMock
import pytest
from starlette.responses import RedirectResponse, JSONResponse, Response
from httpx import AsyncClient

client = TestClient(app)

from unittest.mock import patch, MagicMock


class MockGPTResponse:
def __init__(self, code_completion, user_level_estimation, sentiment_estimation):
self.code_completion = code_completion
self.user_level_estimation = user_level_estimation
self.sentiment_estimation = sentiment_estimation


class MockModerationResponse:
def __init__(self, flagged=False):
self._flagged = flagged

def json(self):
return {
"flagged": self._flagged
}


@pytest.mark.asyncio
async def test_login_via_github():
"""
Test the login via GitHub endpoint with AsyncClient
:return: boolean, True if the test passes with 302 code for redirect
"""
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/login/github")
assert response.status_code == 302
assert "location" in response.headers
assert response.headers["location"].startswith("https://github.com/login/oauth/authorize")


@pytest.mark.asyncio
async def test_gpt_exception_handler():
"""
Function to test the GPT exception handler
:return: boolean, True if the test passes with 400, 404 code for bad request
"""
client = TestClient(app)
response = client.get("/trigger-exception")

assert response.status_code == 400 or response.status_code == 404


@pytest.mark.asyncio
@pytest.mark.parametrize("query_params, model, expected_output", [
(
{
"user_input": "What is the capital of France?Please answer with one word only and dont add dot at the end"},
"text-davinci-003",
"Paris"
),
(
{"user_input": "Which is the capital of UK? Please answer with one word only and dont add dot at the end"},
"text-davinci-003",
"London"
),
# Add more test cases here
])
async def test_ask_gpt4(query_params, model, expected_output):
response = client.post(
"/ask_gpt4/",
json={"user_input": query_params["user_input"], "model": model},
)
async def test_authorize():
"""
Test the authorize endpoint with AsyncClient
:return: boolean, True if the test passes with 307 or 302 code for redirect
"""
with patch('app.oauth.github.authorize_access_token', new_callable=AsyncMock) as mock_oauth:
mock_oauth.return_value = {'access_token': 'test_token'}
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get('/auth/github/callback?code=testcode')

assert response.status_code == 307 or response.status_code == 302


def test_index_page():
"""
Function to test the index page where the user can input a query
:return: boolean True if the test passes
"""
response = client.get("/")
assert response.status_code == 200
json_response = response.json()
assert response.headers["content-type"] == "text/html; charset=utf-8"

# Check if the response contains a valid answer
assert "response" in json_response or "error" in json_response

# If there's an error, check if it's a known error
if "error" in json_response:
assert json_response["error"] in [
"ChatGPT response does not contain text attribute.",
# Add other known errors here
]
else:
assert json_response["response"] == expected_output
def test_login_error():
"""
Function to test the login error message
:return: boolean to check if the test passes
"""
test_message = "Test error message"
response = client.get(f"/login-error?message={test_message}")
assert response.status_code == 400
assert response.headers["content-type"] == "application/json"
response_data = response.json()
assert response_data == {"error": test_message}


def test_logout_redirects_to_index():
"""
Function to test the logout endpoint
:return: boolean, True if the test passes with 307 code for redirect after logout to index
"""
response = client.get("/logout", allow_redirects=False)

assert response.status_code == 307
assert response.headers["location"] == "/"


@pytest.mark.asyncio
@pytest.mark.parametrize("status_code, detail, expected_result", [
(404, "Not Found", {"detail": "Not Found", "status_code": 404}),
(500, "Internal Server Error", {"detail": "Internal Server Error", "status_code": 500}),
(401, "Unauthorized", {"detail": "Unauthorized", "status_code": 401}),
])
async def test_http_exception_handler(status_code: int, detail: Union[str, dict], expected_result: Any) -> None:
async def test_redis_session_middleware():
"""
Function to test http exception handler
:param status_code: int ,status code e.g. 400, 404 etc.
:param detail: str or Dict , detail message
:param expected_result:
:return: None
Function to test the redis session middleware
:return: boolean, True if the test passes with 200 code
"""
exc = HTTPException(status_code=status_code, detail=detail)
result = await http_exception_handler(exc)
assert result == expected_result
with patch('models.query_model.redis_client.get', MagicMock(return_value=None)), \
patch('models.query_model.redis_client.set', MagicMock()) as mock_set, \
patch('uuid.uuid4', MagicMock(return_value=uuid.UUID('12345678-4567-5678-9888-567812345678'))):
response = client.get("/")

assert response.status_code == 200

assert 'session_id' in response.cookies

assert mock_set.called

0 comments on commit c332fec

Please sign in to comment.