Skip to content

Commit

Permalink
Use battle llm instead of arena
Browse files Browse the repository at this point in the history
  • Loading branch information
SupraSummus committed Nov 10, 2024
1 parent b7f87ac commit 4ae14c5
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 8 deletions.
8 changes: 4 additions & 4 deletions warriors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,21 +282,20 @@ class Meta:
class BattleQuerySet(models.QuerySet):
def with_warrior_arena(self, warrior_arena):
return self.filter(
arena_id=warrior_arena.arena_id,
llm=warrior_arena.arena.llm,
).filter(
models.Q(warrior_1_id=warrior_arena.warrior_id) |
models.Q(warrior_2_id=warrior_arena.warrior_id),
)

def with_warrior_arenas(self, warrior_arena_1, warrior_arena_2):
assert warrior_arena_1.arena_id == warrior_arena_2.arena_id
arena_id = warrior_arena_1.arena_id
warrior_1_id = warrior_arena_1.warrior_id
warrior_2_id = warrior_arena_2.warrior_id
if warrior_1_id > warrior_2_id:
warrior_1_id, warrior_2_id = warrior_2_id, warrior_1_id
return self.filter(
arena_id=arena_id,
llm=warrior_arena_1.arena.llm,
warrior_1_id=warrior_1_id,
warrior_2_id=warrior_2_id,
)
Expand Down Expand Up @@ -550,6 +549,7 @@ def map_field_name(self, field_name):
'id',
'arena',
'arena_id',
'llm',
'scheduled_at',
'rating_transferred_at',
):
Expand Down Expand Up @@ -683,7 +683,7 @@ def map_field_name(self, field_name):
return f'warrior_arena_{self.direction_from}'
elif field_name == 'warrior_arena_2':
return f'warrior_arena_{self.direction_to}'
elif field_name == 'arena':
elif field_name in ('arena', 'llm'):
return field_name
else:
return None
Expand Down
7 changes: 5 additions & 2 deletions warriors/rating_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_update_rating_takes_newer_battles(battle):
new_then = then + datetime.timedelta(days=1)
BattleFactory(
arena=battle.arena,
llm=battle.llm,
warrior_1=battle.warrior_1,
warrior_2=battle.warrior_2,
scheduled_at=new_then,
Expand Down Expand Up @@ -55,11 +56,12 @@ def test_rating_is_isolated_for_each_arena():
if warrior_1.id > warrior_2.id:
warrior_1, warrior_2 = warrior_2, warrior_1

arena_1 = ArenaFactory()
arena_1 = ArenaFactory(llm='model1')
warrior_1_arena_1 = WarriorArenaFactory(warrior=warrior_1, arena=arena_1)
warrior_2_arena_1 = WarriorArenaFactory(warrior=warrior_2, arena=arena_1)
BattleFactory(
arena=arena_1,
llm=arena_1.llm,
warrior_1=warrior_1,
warrior_2=warrior_2,
resolved_at_1_2=now,
Expand All @@ -70,11 +72,12 @@ def test_rating_is_isolated_for_each_arena():
lcs_len_2_1_2=1,
)

arena_2 = ArenaFactory()
arena_2 = ArenaFactory(llm='model2')
warrior_1_arena_2 = WarriorArenaFactory(warrior=warrior_1, arena=arena_2)
warrior_2_arena_2 = WarriorArenaFactory(warrior=warrior_2, arena=arena_2)
BattleFactory(
arena=arena_2,
llm=arena_2.llm,
warrior_1=warrior_1,
warrior_2=warrior_2,
resolved_at_1_2=now,
Expand Down
3 changes: 1 addition & 2 deletions warriors/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def resolve_battle_2_1(goal, battle_id):
def resolve_battle(battle_id, direction):
now = timezone.now()
battle = Battle.objects.filter(id=battle_id).select_related(
'arena',
'warrior_1',
'warrior_2',
).get()
Expand All @@ -141,7 +140,7 @@ def resolve_battle(battle_id, direction):
resolve_battle_function = {
LLM.OPENAI_GPT: resolve_battle_openai,
LLM.CLAUDE_3_HAIKU: anthropic.resolve_battle,
}[battle_view.arena.llm]
}[battle_view.llm]

try:
(
Expand Down
1 change: 1 addition & 0 deletions warriors/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def battle(
warrior, other_warrior = other_warrior, warrior
return BattleFactory(
arena=arena,
llm=arena.llm,
warrior_1=warrior,
warrior_2=other_warrior,
**getattr(request, 'param', {}),
Expand Down
1 change: 1 addition & 0 deletions warriors/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_create_battle_lots_of_games_played(warrior_arena, battle, other_warrior
BattleFactory.create_batch(
100,
arena=battle.arena,
llm=battle.llm,
warrior_1=battle.warrior_1,
warrior_2=battle.warrior_2,
)
Expand Down

0 comments on commit 4ae14c5

Please sign in to comment.