Skip to content

Commit

Permalink
Import simplification & dev dependencies (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Oct 25, 2024
1 parent 42b7d92 commit 17e0795
Show file tree
Hide file tree
Showing 19 changed files with 187 additions and 79 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/codeflash.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ jobs:
with:
enable-cache: true
- if: steps.bot_check.outputs.skip_remaining_steps == 'no'
run: uv sync
run: |-
uv sync
uv pip install codeflash
- name: Run CodeFlash on fhaviary
if: steps.bot_check.outputs.skip_remaining_steps == 'no'
run: uv run codeflash
Expand Down
2 changes: 1 addition & 1 deletion packages/gsm8k/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ where = ["src"]

[tool.setuptools_scm]
root = "../.."
version_file = "src/aviary/gsm8k/version.py"
version_file = "src/aviary/envs/gsm8k/version.py"
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 2 additions & 3 deletions packages/gsm8k/tests/test_gsm8k_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest

from aviary.env import Environment, TaskDataset
from aviary.gsm8k.env import CalculatorEnv, CalculatorEnvConfig
from aviary.tools import ToolCall, ToolRequestMessage
from aviary.core import Environment, TaskDataset, ToolCall, ToolRequestMessage
from aviary.envs.gsm8k import CalculatorEnv, CalculatorEnvConfig


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion packages/hotpotqa/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ where = ["src"]

[tool.setuptools_scm]
root = "../.."
version_file = "src/aviary/hotpotqa/version.py"
version_file = "src/aviary/envs/hotpotqa/version.py"
4 changes: 2 additions & 2 deletions packages/hotpotqa/tests/test_hotpotqa_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from aviary.env import Environment, TaskDataset
from aviary.hotpotqa import HotPotQAEnv
from aviary.core import Environment, TaskDataset
from aviary.envs.hotpotqa import HotPotQAEnv


def test_env_construction() -> None:
Expand Down
43 changes: 23 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ requires-python = ">=3.11"
cloud = [
"boto3",
]
dev = [
"SQLAlchemy[aiosqlite]~=2.0", # Match aviary dependencies
"aviary.gsm8k[typing]", # So `uv sync` pulls this in, and for type stubs
"aviary.hotpotqa", # So `uv sync` pulls this in
"codeflash",
"fhaviary[image,llm,server,typing,xml]",
"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"pre-commit>=3.4", # Pin to keep recent
"pydantic~=2.9", # Pydantic 2.9 changed JSON schema exports 'allOf', so ensure tests match
"pylint-pydantic",
"pylint>=3.2",
"pytest-asyncio",
"pytest-recording",
"pytest-subtests",
"pytest-sugar",
"pytest-timer[colorama]",
"pytest-xdist",
"pytest>=8", # Pin to keep recent
"refurb>=2", # Pin to keep recent
"typeguard",
]
gsm8k = ["aviary.gsm8k"]
hotpotqa = ["aviary.hotpotqa"]
image = [
Expand Down Expand Up @@ -401,26 +423,7 @@ trailing_comma_inline_array = true

[tool.uv]
dev-dependencies = [
"SQLAlchemy[aiosqlite]~=2.0", # Match aviary dependencies
"aviary.gsm8k[typing]", # So `uv sync` pulls this in, and for type stubs
"aviary.hotpotqa", # So `uv sync` pulls this in
"codeflash",
"fhaviary[image,llm,server,typing,xml]",
"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"pre-commit>=3.4", # Pin to keep recent
"pydantic~=2.9", # Pydantic 2.9 changed JSON schema exports 'allOf', so ensure tests match
"pylint-pydantic",
"pylint>=3.2",
"pytest-asyncio",
"pytest-recording",
"pytest-subtests",
"pytest-sugar",
"pytest-timer[colorama]",
"pytest-xdist",
"pytest>=8", # Pin to keep recent
"refurb>=2", # Pin to keep recent
"typeguard",
"fhaviary[dev]",
]

[tool.uv.sources]
Expand Down
69 changes: 69 additions & 0 deletions src/aviary/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from aviary.env import (
TASK_DATASET_REGISTRY,
DummyEnv,
DummyEnvState,
DummyTaskDataset,
Environment,
Frame,
TaskConfig,
TaskDataset,
)
from aviary.env_client import EnvironmentClient
from aviary.message import MalformedMessageError, Message, join
from aviary.render import Renderer
from aviary.tools import (
INVALID_TOOL_NAME,
FunctionInfo,
Messages,
MessagesAdapter,
Parameters,
Tool,
ToolCall,
ToolCallFunction,
ToolRequestMessage,
ToolResponseMessage,
Tools,
ToolsAdapter,
ToolSelector,
ToolSelectorLedger,
argref_by_name,
eval_answer,
wraps_doc_only,
)
from aviary.utils import encode_image_to_base64, is_coroutine_callable, partial_format

__all__ = [
"INVALID_TOOL_NAME",
"TASK_DATASET_REGISTRY",
"DummyEnv",
"DummyEnvState",
"DummyTaskDataset",
"Environment",
"EnvironmentClient",
"Frame",
"FunctionInfo",
"MalformedMessageError",
"Message",
"Messages",
"MessagesAdapter",
"Parameters",
"Renderer",
"TaskConfig",
"TaskDataset",
"Tool",
"ToolCall",
"ToolCallFunction",
"ToolRequestMessage",
"ToolResponseMessage",
"ToolSelector",
"ToolSelectorLedger",
"Tools",
"ToolsAdapter",
"argref_by_name",
"encode_image_to_base64",
"eval_answer",
"is_coroutine_callable",
"join",
"partial_format",
"wraps_doc_only",
]
8 changes: 4 additions & 4 deletions src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def available(cls) -> set[str]:
# Maps baseline environment names to their module and class names
ENV_REGISTRY: dict[str, tuple[str, str]] = {
"dummy": ("aviary.env", "DummyEnv"),
"calculator": ("aviary.gsm8k.env", "CalculatorEnv"),
"hotpotqa": ("aviary.hotpotqa.env", "HotPotQAEnv"),
"calculator": ("aviary.envs.gsm8k.env", "CalculatorEnv"),
"hotpotqa": ("aviary.envs.hotpotqa.env", "HotPotQAEnv"),
}

TEnvironment = TypeVar("TEnvironment", bound=Environment)
Expand Down Expand Up @@ -352,8 +352,8 @@ def iter_batches(
# Maps baseline task dataset names to their module and class names
TASK_DATASET_REGISTRY: dict[str, tuple[str, str]] = {
"dummy": ("aviary.env", "DummyTaskDataset"),
"gsm8k": ("aviary.gsm8k.env", "GSM8kDataset"),
"hotpotqa": ("aviary.hotpotqa.env", "HotPotQADataset"),
"gsm8k": ("aviary.envs.gsm8k.env", "GSM8kDataset"),
"hotpotqa": ("aviary.envs.hotpotqa.env", "HotPotQADataset"),
}


Expand Down
12 changes: 8 additions & 4 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from pydantic import BaseModel
from pytest_subtests import SubTests

from aviary.env import DummyEnv, DummyEnvState, Environment, Frame, TaskDataset
from aviary.message import Message
from aviary.render import Renderer
from aviary.tools import (
from aviary.core import (
DummyEnv,
DummyEnvState,
Environment,
Frame,
Message,
Renderer,
TaskDataset,
Tool,
ToolCall,
ToolRequestMessage,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np
import pytest

from aviary.message import Message
from aviary.tools import (
from aviary.core import (
Message,
ToolCall,
ToolCallFunction,
ToolRequestMessage,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from pytest_subtests import SubTests
from typeguard import suppress_type_checks

from aviary.env import DummyEnv, Environment
from aviary.message import Message
from aviary.tools import (
from aviary.core import (
INVALID_TOOL_NAME,
DummyEnv,
Environment,
FunctionInfo,
Message,
Tool,
ToolCall,
ToolRequestMessage,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from aviary.tools import eval_answer
from aviary.core import eval_answer


@pytest.mark.vcr
Expand Down
Loading

0 comments on commit 17e0795

Please sign in to comment.