Skip to content

Commit

Permalink
Optimize warrior view
Browse files Browse the repository at this point in the history
  • Loading branch information
SupraSummus committed Oct 22, 2024
1 parent f495981 commit 08a3b8c
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 1 deletion.
6 changes: 6 additions & 0 deletions llm_wars/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,9 @@
},
},
}

ENABLE_DEBUG_TOOLBAR = env.bool('ENABLE_DEBUG_TOOLBAR', default=False)
if ENABLE_DEBUG_TOOLBAR:
INSTALLED_APPS.append('debug_toolbar')
MIDDLEWARE.insert(0, 'debug_toolbar.middleware.DebugToolbarMiddleware')
INTERNAL_IPS = ['127.0.0.1']
6 changes: 6 additions & 0 deletions llm_wars/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.conf import settings
from django.contrib import admin
from django.contrib.auth.views import LoginView, LogoutView
from django.urls import path, register_converter
Expand Down Expand Up @@ -80,3 +81,8 @@ def to_url(self, value):
path('logout/', LogoutView.as_view(), name='logout'),
path('signup/', SignupView.as_view(), name='signup'),
) + router.urls


if settings.ENABLE_DEBUG_TOOLBAR:
from debug_toolbar.toolbar import debug_toolbar_urls
urlpatterns += tuple(debug_toolbar_urls())
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pylint-django = "*"
flake8-pyproject = "*"
factory-boy = "*"
aider-chat = "*"
django-debug-toolbar = "^4.4.6"

[build-system]
requires = ["poetry-core"]
Expand Down
11 changes: 11 additions & 0 deletions warriors/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from users.tests.factories import UserFactory

from ..models import LLM, Arena, Battle, WarriorArena, WarriorUserPermission
from ..text_unit import TextUnit
from ..warriors import Warrior


Expand Down Expand Up @@ -50,3 +51,13 @@ class Meta:
arena = factory.SubFactory(ArenaFactory)
warrior_1 = factory.SubFactory(WarriorArenaFactory)
warrior_2 = factory.SubFactory(WarriorArenaFactory)


class TextUnitFactory(factory.django.DjangoModelFactory):
class Meta:
model = TextUnit

content = factory.Sequence(lambda n: f'factory-made text unit body {n}')
sha_256 = factory.LazyAttribute(
lambda o: hashlib.sha256(o.content.encode('utf-8')).digest()
)
3 changes: 3 additions & 0 deletions warriors/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils import timezone
from django_goals.models import schedule, worker_turn

from .. import embeddings
from ..exceptions import RateLimitError
from ..models import Battle, WarriorArena
from ..tasks import openai_client, resolve_battle_1_2
Expand All @@ -19,6 +20,8 @@ def test_submit_warrior_e2e(client, mocked_recaptcha, monkeypatch, default_arena
moderation_mock.return_value.results = [moderation_result_mock]
monkeypatch.setattr(openai_client.moderations, 'create', moderation_mock)

monkeypatch.setattr(embeddings, 'get_embedding', mock.MagicMock(return_value=[0.0] * 1024))

response = client.post(
reverse('warrior_create'),
data={
Expand Down
22 changes: 22 additions & 0 deletions warriors/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ..models import Battle
from ..text_unit import TextUnit
from .factories import BattleFactory, TextUnitFactory, WarriorArenaFactory


@pytest.mark.django_db
Expand Down Expand Up @@ -79,6 +80,27 @@ def test_warrior_details_authorized_session(client, warrior_arena, session_autho
assert (warrior_arena.body in response.content.decode()) == session_authorized


@pytest.mark.django_db
def test_warrior_details_do_few_sql_queries(client, arena, warrior_arena, django_assert_max_num_queries):
n = 100
for _ in range(n):
other_warrior_arena = WarriorArenaFactory(arena=arena)
if other_warrior_arena.id < warrior_arena.id:
kwargs = {'warrior_1': other_warrior_arena, 'warrior_2': warrior_arena}
else:
kwargs = {'warrior_1': warrior_arena, 'warrior_2': other_warrior_arena}
BattleFactory(
arena=arena,
**kwargs,
text_unit_1_2=TextUnitFactory(),
text_unit_2_1=TextUnitFactory(),
)
with django_assert_max_num_queries(n // 2):
client.get(
reverse('warrior_detail', args=(warrior_arena.id,))
)


@pytest.mark.django_db
def test_warrior_set_public_battle_results(user_client, warrior, warrior_arena, warrior_user_permission):
assert warrior.public_battle_results is False
Expand Down
4 changes: 4 additions & 0 deletions warriors/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def get_context_data(self, **kwargs):
'warrior_1__warrior',
'warrior_2',
'warrior_2__warrior',
'text_unit_1_2',
'text_unit_2_1',
).prefetch_related(
'arena',
)
context['battles'] = [
battle.get_warrior_viewpoint(self.object)
Expand Down

0 comments on commit 08a3b8c

Please sign in to comment.