Skip to content

Commit

Permalink
tests: unit test instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward-Jackson-ONS committed Aug 19, 2024
1 parent fde8394 commit b930b94
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/readers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for the reader classes."""
263 changes: 263 additions & 0 deletions tests/readers/test_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
"""Unit tests for instantiation methods of our readers."""

import datetime as dt
from unittest import mock

from hypothesis import given, provisional
from hypothesis import strategies as st

from parliai_public.readers import Debates, WrittenAnswers

from ..common import ST_DATES, ST_FREE_TEXT, TODAY, ToyReader, where_what

ST_OPTIONAL_STRINGS = st.one_of((st.just(None), ST_FREE_TEXT))
YESTERDAY = TODAY - dt.timedelta(days=1)


@given(
st.sampled_from((ToyReader, Debates, WrittenAnswers)),
st.lists(provisional.urls(), max_size=5),
st.one_of((st.just(None), st.lists(ST_FREE_TEXT, max_size=5))),
st.one_of(st.just(None), st.lists(ST_DATES, min_size=1, max_size=5)),
ST_FREE_TEXT,
ST_OPTIONAL_STRINGS,
)
def test_init(reader_class, urls, terms, dates, outdir, prompt):
"""Test instantiation occurs correctly."""

where, what = where_what(reader_class)
if reader_class is WrittenAnswers:
urls = reader_class._supported_urls

config = {
"prompt": "",
"llm_name": "gemma",
}
with mock.patch(f"{where}._load_config") as load:
load.return_value = config
reader = reader_class(urls, terms, dates, outdir, prompt)

default_terms = ["Office for National Statistics", "ONS"]
assert isinstance(reader, what)
assert reader.urls == urls
assert reader.terms == default_terms if not terms else terms
assert reader.dates == [YESTERDAY] if dates is None else dates
assert reader.outdir == outdir
assert reader.prompt == ("" if prompt is None else prompt)
assert reader.llm_name == "gemma"

load.assert_called_once_with()


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # ST_OPTIONAL_STRINGS,
# # st.lists(provisional.urls(), max_size=5),
# # st.lists(ST_FREE_TEXT, max_size=5),
# # ST_FREE_TEXT,
# # )
# # def test_from_toml_no_dates(reader_class, path, urls, terms, text):
# # """
# # Test that an instance can be made from a configuration file.

# # In this test, we do not configure any of the date parameters, so
# # every reader instance should have the same `dates` attribute:
# # yesterday.
# # """

# # where, what = where_what(reader_class)
# # if reader_class is WrittenAnswers:
# # urls = reader_class._supported_urls

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": urls,
# # "terms": terms,
# # "outdir": text,
# # "prompt": text,
# # "llm_name": "gemini-pro",
# # }
# # reader = reader_class.from_toml(path)

# # assert isinstance(reader, what)
# # assert reader.urls == urls
# # assert reader.terms == terms
# # assert reader.dates == [YESTERDAY]
# # assert reader.outdir == text
# # assert reader.prompt == text
# # assert reader.llm_name == "gemini-pro"

# # assert loader.return_value["dates"] is None
# # assert loader.call_count == 2
# # assert loader.call_args_list == [mock.call(path), mock.call()]

# # lister.assert_not_called()


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # ST_DATES.map(dt.date.isoformat),
# # )
# # def test_from_toml_with_start(reader_class, start):
# # """
# # Check the config constructor works with a start date.

# # The actual date list construction is mocked here.
# # """
# # where, what = where_what(reader_class)

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": [],
# # "start": start,
# # "prompt": "",
# # "llm_name": "gemini-pro",
# # }
# # lister.return_value = ["dates"]
# # reader = reader_class.from_toml()

# # assert isinstance(reader, what)
# # assert reader.dates == ["dates"]

# # assert "start" not in loader.return_value
# # assert loader.return_value.get("dates") == ["dates"]

# # lister.assert_called_once_with(start, None, None, "%Y-%m-%d")


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # ST_DATES.map(dt.date.isoformat),
# # )
# # def test_from_toml_with_end(reader_class, end):
# # """Check the config constructor works with an end date."""

# # where, what = where_what(reader_class)

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": [],
# # "end": end,
# # "prompt": "",
# # "llm_name": "gemini-pro",
# # }
# # lister.return_value = ["dates"]
# # reader = reader_class.from_toml()

# # assert isinstance(reader, what)
# # assert reader.dates == ["dates"]

# # assert "end" not in loader.return_value
# # assert loader.return_value.get("dates") == ["dates"]

# # lister.assert_called_once_with(None, end, None, "%Y-%m-%d")


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # st.tuples(ST_DATES, ST_DATES).map(
# # lambda dates: sorted(map(dt.date.isoformat, dates))
# # ),
# # )
# # def test_from_toml_with_endpoints(reader_class, endpoints):
# # """Check the config constructor works with two endpoints."""

# # where, what = where_what(reader_class)
# # start, end = endpoints

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": [],
# # "start": start,
# # "end": end,
# # "prompt": "",
# # "llm_name": "gemini-pro",
# # }
# # lister.return_value = ["dates"]
# # reader = reader_class.from_toml()

# # assert isinstance(reader, what)
# # assert reader.dates == ["dates"]

# # assert "start" not in loader.return_value
# # assert "end" not in loader.return_value
# # assert loader.return_value.get("dates") == ["dates"]

# # lister.assert_called_once_with(start, end, None, "%Y-%m-%d")


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # st.integers(1, 14),
# # )
# # def test_from_toml_with_window(reader_class, window):
# # """Check the config constructor works with a window."""

# # where, what = where_what(reader_class)

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": [],
# # "window": window,
# # "prompt": "",
# # "llm_name": "gemini-pro",
# # }
# # lister.return_value = ["dates"]
# # reader = reader_class.from_toml()

# # assert isinstance(reader, what)
# # assert reader.dates == ["dates"]

# # assert "end" not in loader.return_value
# # assert loader.return_value.get("dates") == ["dates"]

# # lister.assert_called_once_with(None, None, window, "%Y-%m-%d")


# # @given(
# # st.sampled_from((ToyReader, Debates, WrittenAnswers)),
# # ST_DATES.map(dt.date.isoformat),
# # st.integers(1, 14),
# # )
# # def test_from_toml_with_end_and_window(reader_class, end, window):
# # """Check the config constructor works with an end and a window."""

# # where, what = where_what(reader_class)

# # with (
# # mock.patch(f"{where}._load_config") as loader,
# # mock.patch("parliai.dates.list_dates") as lister,
# # ):
# # loader.return_value = {
# # "urls": [],
# # "end": end,
# # "window": window,
# # "prompt": "",
# # "llm_name": "gemini-pro",
# # }
# # lister.return_value = ["dates"]
# # reader = reader_class.from_toml()

# # assert isinstance(reader, what)
# # assert reader.dates == ["dates"]

# # assert "end" not in loader.return_value
# # assert "window" not in loader.return_value
# # assert loader.return_value.get("dates") == ["dates"]

# # lister.assert_called_once_with(None, end, window, "%Y-%m-%d")

0 comments on commit b930b94

Please sign in to comment.