From 589ad33426fb90b811f17f64bfa38789a2b0377f Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:59:36 -0400 Subject: [PATCH] fix(tests) Clean up test logic --- pyproject.toml | 6 +++--- src/routers/messages.py | 16 ++++++++++------ tests/conftest.py | 4 ++-- tests/routes/test_messages.py | 10 +++++----- tests/routes/test_metamessages.py | 12 ++++++------ uv.lock | 20 ++++++++++++-------- 6 files changed, 38 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index af4ec4e..e6f1893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "honcho" -version = "0.0.11" +version = "0.0.12" description = "Honcho Server" authors = [ {name = "Plastic Labs", email = "hello@plasticlabs.ai"}, @@ -26,8 +26,8 @@ dependencies = [ "mirascope>=0.18.0", "openai>=1.43.0", ] -[project.optional-dependencies] -test = [ +[tool.uv] +dev-dependencies = [ "pytest>=8.2.2", "sqlalchemy-utils>=0.41.2", "pytest-asyncio>=0.23.7", diff --git a/src/routers/messages.py b/src/routers/messages.py index 28bb76d..95ee2b6 100644 --- a/src/routers/messages.py +++ b/src/routers/messages.py @@ -28,12 +28,16 @@ async def enqueue(payload: dict): session_id=payload["session_id"], ) # Check if metadata has a "deriver" key - deriver_disabled = session.h_metadata.get("deriver_disabled") - if deriver_disabled is not None and deriver_disabled is not False: - print("=====================") - print(f"Deriver is not enabled on session {payload['session_id']}") - print("=====================") - # If deriver is not enabled, do not enqueue + if session is not None: + deriver_disabled = session.h_metadata.get("deriver_disabled") + if deriver_disabled is not None and deriver_disabled is not False: + print("=====================") + print(f"Deriver is not enabled on session {payload['session_id']}") + print("=====================") + # If deriver is not enabled, do not enqueue + return + else: + # Session doesn't exist return return try: processed_payload = { diff --git a/tests/conftest.py b/tests/conftest.py index 32ce050..4288afe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -125,12 +125,12 @@ async def override_get_db(): async def sample_data(db_session): """Helper function to create test data""" # Create test app - test_app = models.App(name=str(uuid.uuid4()), metadata={}) + test_app = models.App(name=str(uuid.uuid4())) db_session.add(test_app) await db_session.flush() # Create test user - test_user = models.User(name=str(uuid.uuid4()), app_id=test_app.id, metadata={}) + test_user = models.User(name=str(uuid.uuid4()), app_id=test_app.id) db_session.add(test_user) await db_session.flush() diff --git a/tests/routes/test_messages.py b/tests/routes/test_messages.py index 7028aaf..000207d 100644 --- a/tests/routes/test_messages.py +++ b/tests/routes/test_messages.py @@ -7,7 +7,7 @@ async def test_create_message(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() @@ -31,11 +31,11 @@ async def test_create_message(client, db_session, sample_data): async def test_get_messages(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session and message - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True, metadata={} + session_id=test_session.id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() @@ -56,11 +56,11 @@ async def test_get_messages(client, db_session, sample_data): async def test_update_message(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session and message - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True, metadata={} + session_id=test_session.id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() diff --git a/tests/routes/test_metamessages.py b/tests/routes/test_metamessages.py index d2c81a7..2d7691b 100644 --- a/tests/routes/test_metamessages.py +++ b/tests/routes/test_metamessages.py @@ -7,11 +7,11 @@ async def test_create_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True, metadata={} + session_id=test_session.id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() @@ -37,11 +37,11 @@ async def test_create_metamessage(client, db_session, sample_data): async def test_get_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True, metadata={} + session_id=test_session.id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() @@ -69,11 +69,11 @@ async def test_get_metamessage(client, db_session, sample_data): async def test_update_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True, metadata={} + session_id=test_session.id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() diff --git a/uv.lock b/uv.lock index 58c86cc..6c3ecb0 100644 --- a/uv.lock +++ b/uv.lock @@ -455,7 +455,7 @@ wheels = [ [[package]] name = "honcho" -version = "0.0.11" +version = "0.0.12" source = { virtual = "." } dependencies = [ { name = "fastapi", extra = ["standard"] }, @@ -477,8 +477,8 @@ dependencies = [ { name = "sqlalchemy" }, ] -[package.optional-dependencies] -test = [ +[package.dev-dependencies] +dev = [ { name = "coverage" }, { name = "interrogate" }, { name = "pytest" }, @@ -488,12 +488,10 @@ test = [ [package.metadata] requires-dist = [ - { name = "coverage", marker = "extra == 'test'", specifier = ">=7.6.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.111.0" }, { name = "fastapi-pagination", specifier = ">=0.12.24" }, { name = "greenlet", specifier = ">=3.0.3" }, { name = "httpx", specifier = ">=0.27.0" }, - { name = "interrogate", marker = "extra == 'test'", specifier = ">=1.7.0" }, { name = "mirascope", specifier = ">=0.18.0" }, { name = "openai", specifier = ">=1.43.0" }, { name = "opentelemetry-exporter-otlp", specifier = ">=1.24.0" }, @@ -503,13 +501,19 @@ requires-dist = [ { name = "opentelemetry-sdk", specifier = ">=1.24.0" }, { name = "pgvector", specifier = ">=0.2.5" }, { name = "psycopg", extras = ["binary"], specifier = ">=3.1.19" }, - { name = "pytest", marker = "extra == 'test'", specifier = ">=8.2.2" }, - { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.23.7" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "rich", specifier = ">=13.7.1" }, { name = "sentry-sdk", extras = ["fastapi", "sqlalchemy"], specifier = ">=2.3.1" }, { name = "sqlalchemy", specifier = ">=2.0.30" }, - { name = "sqlalchemy-utils", marker = "extra == 'test'", specifier = ">=0.41.2" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "coverage", specifier = ">=7.6.0" }, + { name = "interrogate", specifier = ">=1.7.0" }, + { name = "pytest", specifier = ">=8.2.2" }, + { name = "pytest-asyncio", specifier = ">=0.23.7" }, + { name = "sqlalchemy-utils", specifier = ">=0.41.2" }, ] [[package]]