Skip to content

Commit

Permalink
refactor tests (#434)
Browse files Browse the repository at this point in the history
* refactor tests

* import watch from hf
  • Loading branch information
Ben Epstein authored Nov 2, 2022
1 parent 0439ab2 commit 7ba5846
Show file tree
Hide file tree
Showing 42 changed files with 79 additions and 71 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"dataquality"

__version__ = "v0.7.2"
__version__ = "v0.7.3"

import os
import resource
Expand Down
3 changes: 3 additions & 0 deletions dataquality/integrations/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
from dataquality.exceptions import GalileoException, GalileoWarning

# We add this here so users can `from dataquality.integrations.hf import watch`
from dataquality.integrations.transformers_trainer import watch # noqa: F401
from dataquality.schemas.hf import HFCol
from dataquality.schemas.ner import TaggingSchema
from dataquality.schemas.split import conform_split
Expand Down
2 changes: 2 additions & 0 deletions dataquality/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ def _process_exported_dataframe(
if data_df[col].ndim > 1:
return data_df
pdf = data_df.to_pandas_df()
# The spans come back as json.dumps string data, we can load it for our users
# Back into JSON data so they get the actual span objects
if task_type == TaskType.text_ner and "spans" in pdf.columns:
pdf["spans"] = pdf["spans"].apply(json.loads)
return pdf
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py → tests/clients/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dataquality.clients.api import ApiClient
from dataquality.exceptions import GalileoException
from dataquality.schemas.task_type import TaskType
from tests.utils.mock_request import (
from tests.test_utils.mock_request import (
EXISTING_PROJECT,
EXISTING_RUN,
FAKE_NEW_RUN,
Expand Down
12 changes: 7 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataquality.loggers import BaseGalileoLogger
from dataquality.schemas.task_type import TaskType
from dataquality.utils.dq_logger import DQ_LOG_FILE_HOME
from tests.utils.mock_request import MockResponse
from tests.test_utils.mock_request import MockResponse

DEFAULT_API_URL = "http://localhost:8088"
DEFAULT_PROJECT_ID = UUID("399057bc-b276-4027-a5cf-48893ac45388")
Expand Down Expand Up @@ -52,7 +52,6 @@ def cleanup_after_use() -> Generator:
for task_type in list(TaskType):
dataquality.get_model_logger(task_type).logger_config.reset()
try:
dataquality.get_model_logger().logger_config.reset()
if os.path.isdir(BaseGalileoLogger.LOG_FILE_DIR):
shutil.rmtree(BaseGalileoLogger.LOG_FILE_DIR)
if not os.path.isdir(TEST_PATH):
Expand All @@ -63,9 +62,12 @@ def cleanup_after_use() -> Generator:
os.makedirs(DQ_LOG_FILE_LOCATION)
yield
finally:
shutil.rmtree(BaseGalileoLogger.LOG_FILE_DIR)
shutil.rmtree(DQ_LOG_FILE_LOCATION)
dataquality.get_model_logger().logger_config.reset()
if os.path.exists(BaseGalileoLogger.LOG_FILE_DIR):
shutil.rmtree(BaseGalileoLogger.LOG_FILE_DIR)
if os.path.exists(DQ_LOG_FILE_LOCATION):
shutil.rmtree(DQ_LOG_FILE_LOCATION)
for task_type in list(TaskType):
dataquality.get_model_logger(task_type).logger_config.reset()


@pytest.fixture()
Expand Down
5 changes: 4 additions & 1 deletion tests/test_auth.py → tests/core/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import dataquality
from dataquality.core.auth import GALILEO_AUTH_METHOD
from dataquality.exceptions import GalileoException
from tests.utils.mock_request import mocked_failed_login_requests, mocked_login_requests
from tests.test_utils.mock_request import (
mocked_failed_login_requests,
mocked_login_requests,
)

config = dataquality.config

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_config.py → tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
url_is_localhost,
)
from dataquality.exceptions import GalileoException
from tests.utils.mock_request import MockResponse
from tests.test_utils.mock_request import MockResponse


def test_console_url(set_test_config: Callable) -> None:
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_init.py → tests/core/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataquality.core.init import _Init
from dataquality.exceptions import GalileoException
from tests.exceptions import LoginInvoked
from tests.utils.mock_request import (
from tests.test_utils.mock_request import (
EXISTING_PROJECT,
EXISTING_RUN,
MockResponse,
Expand Down
10 changes: 0 additions & 10 deletions tests/inference/conftest.py

This file was deleted.

16 changes: 12 additions & 4 deletions tests/inference/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Type
from typing import Callable, List, Type
from unittest import mock

import pytest
Expand All @@ -13,7 +13,9 @@


class TestSetSplitInference:
def test_set_split_inference(self) -> None:
def test_set_split_inference(
self, set_test_config: Callable, cleanup_after_use: Callable
) -> None:
assert not dataquality.get_data_logger().logger_config.inference_logged
dataquality.set_split("inference", "all-customers")
assert dataquality.get_data_logger().logger_config.cur_split == "inference"
Expand All @@ -22,7 +24,9 @@ def test_set_split_inference(self) -> None:
== "all-customers"
)

def test_set_split_inference_missing_inference_name(self) -> None:
def test_set_split_inference_missing_inference_name(
self, set_test_config: Callable, cleanup_after_use: Callable
) -> None:
with pytest.raises(ValidationError) as e:
dataquality.set_split("inference")

Expand Down Expand Up @@ -89,7 +93,11 @@ def test_base_model_logger_validate_inference_missing_inference_name(self) -> No
return_value="1234-abcd-5678",
)
def test_write_model_output_inference(
self, mock_uuid: mock.MagicMock, mock_save_file: mock.MagicMock
self,
mock_uuid: mock.MagicMock,
mock_save_file: mock.MagicMock,
set_test_config: Callable,
cleanup_after_use: Callable,
) -> None:
inference_data = {
"epoch": [None, None, None],
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from dataquality.schemas.ner import TaggingSchema
from dataquality.utils.hf_tokenizer import extract_gold_spans_at_word_level
from tests.utils.hf_integration_constants import (
from tests.test_utils.hf_integration_constants import (
ADJUSTED_TOKEN_DATA,
UNADJUSTED_TOKEN_DATA,
BILOUSequence,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
from dataquality.schemas.task_type import TaskType
from dataquality.utils.thread_pool import ThreadPoolManager
from tests.conftest import LOCATION
from tests.utils.hf_datasets_mock import mock_dataset, mock_dataset_repeat
from tests.utils.mock_request import mocked_create_project_run, mocked_get_project_run
from tests.test_utils.hf_datasets_mock import mock_dataset, mock_dataset_repeat
from tests.test_utils.mock_request import (
mocked_create_project_run,
mocked_get_project_run,
)

# Load models locally
try:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from dataquality.schemas.task_type import TaskType
from dataquality.utils.thread_pool import ThreadPoolManager
from tests.conftest import LOCATION
from tests.utils.spacy_integration import load_ner_data_from_local, train_model
from tests.utils.spacy_integration_constants import (
from tests.test_utils.spacy_integration import load_ner_data_from_local, train_model
from tests.test_utils.spacy_integration_constants import (
LONG_SHORT_DATA,
LONG_TRAIN_DATA,
MISALIGNED_SPAN_DATA,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_ner.py → tests/loggers/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataquality.schemas.task_type import TaskType
from dataquality.utils.thread_pool import ThreadPoolManager
from tests.conftest import TEST_PATH
from tests.utils.ner_constants import (
from tests.test_utils.ner_constants import (
GOLD_SPANS,
LABELS,
NER_INPUT_DATA,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
37 changes: 36 additions & 1 deletion tests/test_dataquality.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
import dataquality.core._config
import dataquality.core.finish
from dataquality.exceptions import GalileoException, GalileoWarning, LogBatchError
from dataquality.loggers import BaseGalileoLogger
from dataquality.loggers.data_logger import BaseGalileoDataLogger
from dataquality.loggers.model_logger import BaseGalileoModelLogger
from dataquality.loggers.model_logger.text_classification import (
TextClassificationModelLogger,
)
from dataquality.schemas.task_type import TaskType
from dataquality.utils.thread_pool import ThreadPoolManager
from tests.conftest import TEST_PATH
from tests.utils.data_utils import (
from tests.test_utils.data_utils import (
NUM_LOGS,
NUM_RECORDS,
_log_text_classification_data,
Expand Down Expand Up @@ -751,3 +753,36 @@ def test_cloud_restricts_inference_mode(mock_cloud: MagicMock) -> None:
"accounts can access this feature. Please email us at [email protected] for "
"more information."
)


def test_attribute_subsets() -> None:
"""All potential logging fields used by all subclass loggers should be encapsulated
Any new logger that is created has a set of attributes that it expects from users.
The `BaseLoggerAttributes` from the BaseGalileoLogger should be the superset of
all child loggers.
"""
all_attrs = set(BaseGalileoLogger.get_valid_attributes())
sub_data_loggers = BaseGalileoDataLogger.__subclasses__()
data_logger_attrs = set(
[j for i in sub_data_loggers for j in i.get_valid_attributes()]
)
sub_model_loggers = BaseGalileoModelLogger.__subclasses__()
model_logger_attrs = set(
[j for i in sub_model_loggers for j in i.get_valid_attributes()]
)
all_sub_attrs = data_logger_attrs.union(model_logger_attrs)
assert all_attrs.issuperset(
all_sub_attrs
), f"Missing attrs: {all_sub_attrs - all_attrs}"


def test_int_labels(set_test_config: Callable) -> None:
dataquality.set_labels_for_run(labels=[1, 2, 3, 4, 5]) # type: ignore
assert dataquality.get_data_logger().logger_config.labels == [
"1",
"2",
"3",
"4",
"5",
]
39 changes: 0 additions & 39 deletions tests/test_logger.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
import numpy as np
import pandas as pd

Expand Down
Empty file removed tests/utils/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_version.py → tests/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataquality import __version__
from dataquality.exceptions import GalileoException
from dataquality.utils import version
from tests.utils.mock_request import (
from tests.test_utils.mock_request import (
mocked_healthcheck_request,
mocked_healthcheck_request_new_api_version,
)
Expand Down

0 comments on commit 7ba5846

Please sign in to comment.