-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
144 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,129 @@ | ||
import json | ||
import uuid | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
from fastapi import HTTPException | ||
from typing import Union, Any | ||
from app import ask_gpt4, app, http_exception_handler | ||
from models.query_model import QueryModel | ||
from exceptions.GPTException import GPTException | ||
from models.query_model import QueryModel, RedisSessionMiddleware | ||
from unittest.mock import patch, Mock | ||
from fastapi.responses import RedirectResponse | ||
from starlette.requests import Request | ||
from unittest.mock import patch, AsyncMock | ||
import pytest | ||
from starlette.responses import RedirectResponse, JSONResponse, Response | ||
from httpx import AsyncClient | ||
|
||
client = TestClient(app) | ||
|
||
from unittest.mock import patch, MagicMock | ||
|
||
|
||
class MockGPTResponse: | ||
def __init__(self, code_completion, user_level_estimation, sentiment_estimation): | ||
self.code_completion = code_completion | ||
self.user_level_estimation = user_level_estimation | ||
self.sentiment_estimation = sentiment_estimation | ||
|
||
|
||
class MockModerationResponse: | ||
def __init__(self, flagged=False): | ||
self._flagged = flagged | ||
|
||
def json(self): | ||
return { | ||
"flagged": self._flagged | ||
} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_login_via_github(): | ||
""" | ||
Test the login via GitHub endpoint with AsyncClient | ||
:return: boolean, True if the test passes with 302 code for redirect | ||
""" | ||
async with AsyncClient(app=app, base_url="http://test") as ac: | ||
response = await ac.get("/login/github") | ||
assert response.status_code == 302 | ||
assert "location" in response.headers | ||
assert response.headers["location"].startswith("https://github.com/login/oauth/authorize") | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_gpt_exception_handler(): | ||
""" | ||
Function to test the GPT exception handler | ||
:return: boolean, True if the test passes with 400, 404 code for bad request | ||
""" | ||
client = TestClient(app) | ||
response = client.get("/trigger-exception") | ||
|
||
assert response.status_code == 400 or response.status_code == 404 | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("query_params, model, expected_output", [ | ||
( | ||
{ | ||
"user_input": "What is the capital of France?Please answer with one word only and dont add dot at the end"}, | ||
"text-davinci-003", | ||
"Paris" | ||
), | ||
( | ||
{"user_input": "Which is the capital of UK? Please answer with one word only and dont add dot at the end"}, | ||
"text-davinci-003", | ||
"London" | ||
), | ||
# Add more test cases here | ||
]) | ||
async def test_ask_gpt4(query_params, model, expected_output): | ||
response = client.post( | ||
"/ask_gpt4/", | ||
json={"user_input": query_params["user_input"], "model": model}, | ||
) | ||
async def test_authorize(): | ||
""" | ||
Test the authorize endpoint with AsyncClient | ||
:return: boolean, True if the test passes with 307 or 302 code for redirect | ||
""" | ||
with patch('app.oauth.github.authorize_access_token', new_callable=AsyncMock) as mock_oauth: | ||
mock_oauth.return_value = {'access_token': 'test_token'} | ||
async with AsyncClient(app=app, base_url="http://test") as ac: | ||
response = await ac.get('/auth/github/callback?code=testcode') | ||
|
||
assert response.status_code == 307 or response.status_code == 302 | ||
|
||
|
||
def test_index_page(): | ||
""" | ||
Function to test the index page where the user can input a query | ||
:return: boolean True if the test passes | ||
""" | ||
response = client.get("/") | ||
assert response.status_code == 200 | ||
json_response = response.json() | ||
assert response.headers["content-type"] == "text/html; charset=utf-8" | ||
|
||
# Check if the response contains a valid answer | ||
assert "response" in json_response or "error" in json_response | ||
|
||
# If there's an error, check if it's a known error | ||
if "error" in json_response: | ||
assert json_response["error"] in [ | ||
"ChatGPT response does not contain text attribute.", | ||
# Add other known errors here | ||
] | ||
else: | ||
assert json_response["response"] == expected_output | ||
def test_login_error(): | ||
""" | ||
Function to test the login error message | ||
:return: boolean to check if the test passes | ||
""" | ||
test_message = "Test error message" | ||
response = client.get(f"/login-error?message={test_message}") | ||
assert response.status_code == 400 | ||
assert response.headers["content-type"] == "application/json" | ||
response_data = response.json() | ||
assert response_data == {"error": test_message} | ||
|
||
|
||
def test_logout_redirects_to_index(): | ||
""" | ||
Function to test the logout endpoint | ||
:return: boolean, True if the test passes with 307 code for redirect after logout to index | ||
""" | ||
response = client.get("/logout", allow_redirects=False) | ||
|
||
assert response.status_code == 307 | ||
assert response.headers["location"] == "/" | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("status_code, detail, expected_result", [ | ||
(404, "Not Found", {"detail": "Not Found", "status_code": 404}), | ||
(500, "Internal Server Error", {"detail": "Internal Server Error", "status_code": 500}), | ||
(401, "Unauthorized", {"detail": "Unauthorized", "status_code": 401}), | ||
]) | ||
async def test_http_exception_handler(status_code: int, detail: Union[str, dict], expected_result: Any) -> None: | ||
async def test_redis_session_middleware(): | ||
""" | ||
Function to test http exception handler | ||
:param status_code: int ,status code e.g. 400, 404 etc. | ||
:param detail: str or Dict , detail message | ||
:param expected_result: | ||
:return: None | ||
Function to test the redis session middleware | ||
:return: boolean, True if the test passes with 200 code | ||
""" | ||
exc = HTTPException(status_code=status_code, detail=detail) | ||
result = await http_exception_handler(exc) | ||
assert result == expected_result | ||
with patch('models.query_model.redis_client.get', MagicMock(return_value=None)), \ | ||
patch('models.query_model.redis_client.set', MagicMock()) as mock_set, \ | ||
patch('uuid.uuid4', MagicMock(return_value=uuid.UUID('12345678-4567-5678-9888-567812345678'))): | ||
response = client.get("/") | ||
|
||
assert response.status_code == 200 | ||
|
||
assert 'session_id' in response.cookies | ||
|
||
assert mock_set.called |