Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
test for __main__
Browse files Browse the repository at this point in the history
Signed-off-by: Sarad Mohanan <[email protected]>
  • Loading branch information
sar009 committed Jan 10, 2024
1 parent 3ed4bce commit b2a1542
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 67 deletions.
75 changes: 42 additions & 33 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from datetime import datetime
from itertools import islice
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Sequence, Union

import click
import rich
Expand Down Expand Up @@ -361,9 +361,9 @@ def _get_dbs(
return db1, db2


def _set_age(options, min_age, max_age, db1) -> None:
def _set_age(options: dict, min_age: Optional[str], max_age: Optional[str], db: Database) -> None:
if min_age or max_age:
now: datetime = db1.query(current_timestamp(), datetime).replace(tzinfo=None)
now: datetime = db.query(current_timestamp(), datetime).replace(tzinfo=None)
try:
if max_age:
options["min_update"] = parse_time_before(now, max_age)
Expand All @@ -374,18 +374,18 @@ def _set_age(options, min_age, max_age, db1) -> None:


def _get_table_differ(
algorithm,
db1,
db2,
threaded,
threads,
assume_unique_key,
sample_exclusive_rows,
materialize_all_rows,
table_write_limit,
materialize_to_table,
bisection_factor,
bisection_threshold,
algorithm: str,
db1: Database,
db2: Database,
threaded: bool,
threads: int,
assume_unique_key: bool,
sample_exclusive_rows: bool,
materialize_all_rows: bool,
table_write_limit: int,
materialize_to_table: Optional[str],
bisection_factor: Optional[int],
bisection_threshold: Optional[int],
) -> TableDiffer:
algorithm = Algorithm(algorithm)
if algorithm == Algorithm.AUTO:
Expand All @@ -405,14 +405,14 @@ def _get_table_differ(
materialize_to_table and db1.dialect.parse_table_name(eval_name_template(materialize_to_table))
),
)
else:
assert algorithm == Algorithm.HASHDIFF
return HashDiffer(
bisection_factor=bisection_factor,
bisection_threshold=bisection_threshold,
threaded=threaded,
max_threadpool_size=threads and threads * 2,
)

assert algorithm == Algorithm.HASHDIFF
return HashDiffer(
bisection_factor=DEFAULT_BISECTION_FACTOR if bisection_factor is None else bisection_factor,
bisection_threshold=DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else bisection_threshold,
threaded=threaded,
max_threadpool_size=threads and threads * 2,
)


def _print_result(stats, json_output, diff_iter) -> None:
Expand All @@ -436,8 +436,18 @@ def _print_result(stats, json_output, diff_iter) -> None:
sys.stdout.flush()


def _get_expanded_columns(columns, case_sensitive, mutual, db1, schema1, table1, db2, schema2, table2) -> set:
expanded_columns = set()
def _get_expanded_columns(
columns: list[str],
case_sensitive: bool,
mutual: set[str],
db1: Database,
schema1: dict,
table1: str,
db2: Database,
schema2: dict,
table2: str,
) -> set[str]:
expanded_columns: set[str] = set()
for c in columns:
cc = c if case_sensitive else c.lower()
match = set(match_like(cc, mutual))
Expand All @@ -451,7 +461,7 @@ def _get_expanded_columns(columns, case_sensitive, mutual, db1, schema1, table1,
return expanded_columns


def _get_threads(threads, threads1, threads2) -> Tuple[bool, int]:
def _get_threads(threads: Union[int, str, None], threads1: Optional[int], threads2: Optional[int]) -> Tuple[bool, int]:
threaded = True
if threads is None:
threads = 1
Expand Down Expand Up @@ -519,9 +529,6 @@ def _data_diff(
return

key_columns = key_columns or ("id",)
bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor)
bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold)

threaded, threads = _get_threads(threads, threads1, threads2)
start = time.monotonic()

Expand All @@ -531,12 +538,14 @@ def _data_diff(
)
return

db1: Database
db2: Database
db1, db2 = _get_dbs(threads, database1, threads1, database2, threads2, interactive)
with db1, db2:
options = dict(
case_sensitive=case_sensitive,
where=where,
)
options = {
"case_sensitive": case_sensitive,
"where": where,
}

_set_age(options, min_age, max_age, db1)
dbs: Tuple[Database, Database] = db1, db2
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def str_to_checksum(str: str):


class DiffTestCase(unittest.TestCase):
"Sets up two tables for diffing"
"""Sets up two tables for diffing"""

db_cls = None
src_schema = None
Expand Down
173 changes: 160 additions & 13 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,194 @@
import unittest

from data_diff import Database
from data_diff import Database, JoinDiffer, HashDiffer
from data_diff import databases as db
from data_diff.__main__ import _get_dbs
from tests.common import CONN_STRINGS
from data_diff.__main__ import _get_dbs, _set_age, _get_table_differ, _get_expanded_columns, _get_threads
from data_diff.databases.mysql import MySQL
from data_diff.diff_tables import TableDiffer
from tests.common import CONN_STRINGS, get_conn, DiffTestCase


class TestMain(unittest.TestCase):
def test__get_dbs(self):
class TestGetDBS(unittest.TestCase):
def test__get_dbs(self) -> None:
db1: Database
db2: Database
db1_str: str = CONN_STRINGS[db.PostgreSQL]
db2_str: str = CONN_STRINGS[db.PostgreSQL]

# no threads and 2 threads1 with no interactive
# no threads and 2 threads1
db1, db2 = _get_dbs(0, db1_str, 2, db2_str, 0, False)
with db1, db2:
assert db1 == db2
assert db1.thread_count == 2
assert not db1._interactive

# 3 threads and 0 threads1 with interactive
db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, True)
# 3 threads and 0 threads1
db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, False)
with db1, db2:
assert db1 == db2
assert db1.thread_count == 3

# not interactive
db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False)
with db1, db2:
assert db1 == db2
assert not db1._interactive

# interactive
db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, True)
with db1, db2:
assert db1 == db2
assert db1._interactive

db2_str: str = CONN_STRINGS[db.MySQL]

# no threads and 1 threads1 and 2 thread2 with no interactive
# no threads and 1 threads1 and 2 thread2
db1, db2 = _get_dbs(0, db1_str, 1, db2_str, 2, False)
with db1, db2:
assert db1 != db2
assert db1.thread_count == 1
assert db2.thread_count == 2
assert not db1._interactive

# 3 threads and 0 threads1 and 0 thread2 with interactive
db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, True)
# 3 threads and 0 threads1 and 0 thread2
db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, False)
with db1, db2:
assert db1 != db2
assert db1.thread_count == 3
assert db2.thread_count == 3
assert db1.thread_count == db2.thread_count

# not interactive
db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False)
with db1, db2:
assert db1 != db2
assert not db1._interactive
assert not db2._interactive

# interactive
db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, True)
with db1, db2:
assert db1 != db2
assert db1._interactive
assert db2._interactive


class TestSetAge(unittest.TestCase):
    """Checks which keys ``_set_age`` writes into the options dict.

    Runs against a PostgreSQL connection obtained through ``get_conn``.
    """

    def setUp(self) -> None:
        self.database: Database = get_conn(db.PostgreSQL)

    def tearDown(self) -> None:
        self.database.close()

    def test__set_age(self) -> None:
        # Each case starts from a fresh dict. With a single shared dict, the
        # "max_update" key written by the min_age case would still be present
        # in the max_age case, and its `len(options) == 1` check would fail.

        # Neither bound given: nothing is written.
        options: dict = {}
        _set_age(options, None, None, self.database)
        assert len(options) == 0

        # min_age only: the upper bound ("max_update") is set.
        options = {}
        _set_age(options, "1d", None, self.database)
        assert len(options) == 1
        assert options.get("max_update") is not None

        # max_age only: the lower bound ("min_update") is set.
        options = {}
        _set_age(options, None, "1d", self.database)
        assert len(options) == 1
        assert options.get("min_update") is not None

        # Both given: both bounds are set.
        options = {}
        _set_age(options, "1d", "1d", self.database)
        assert len(options) == 2
        assert options.get("max_update") is not None
        assert options.get("min_update") is not None


class TestGetTableDiffer(unittest.TestCase):
    """Checks which differ class ``_get_table_differ`` returns per algorithm.

    Covered twice: once for a pair of connections that compare equal
    (PostgreSQL/PostgreSQL) and once for a pair that do not (PostgreSQL/MySQL).
    """

    def test__get_table_differ(self):
        postgres_uri: str = CONN_STRINGS[db.PostgreSQL]

        # Same connection string on both sides.
        left, right = _get_dbs(1, postgres_uri, 0, postgres_uri, 0, False)
        with left, right:
            assert left == right
            same_db_expectations = (
                ("auto", JoinDiffer),
                ("joindiff", JoinDiffer),
                ("hashdiff", HashDiffer),
            )
            for algorithm, expected_cls in same_db_expectations:
                differ: TableDiffer = self._get_differ(algorithm, left, right)
                assert isinstance(differ, expected_cls)

        # Different databases on each side.
        mysql_uri: str = CONN_STRINGS[db.MySQL]
        left, right = _get_dbs(1, postgres_uri, 0, mysql_uri, 0, False)
        with left, right:
            assert left != right
            cross_db_expectations = (
                ("auto", HashDiffer),
                ("joindiff", JoinDiffer),
                ("hashdiff", HashDiffer),
            )
            for algorithm, expected_cls in cross_db_expectations:
                differ = self._get_differ(algorithm, left, right)
                assert isinstance(differ, expected_cls)

    @staticmethod
    def _get_differ(algorithm, db1, db2):
        """Build a differ for ``algorithm`` with the remaining settings fixed."""
        return _get_table_differ(
            algorithm=algorithm,
            db1=db1,
            db2=db2,
            threaded=False,
            threads=1,
            assume_unique_key=False,
            sample_exclusive_rows=False,
            materialize_all_rows=False,
            table_write_limit=1,
            materialize_to_table=None,
            bisection_factor=None,
            bisection_threshold=None,
        )


class TestGetExpandedColumns(DiffTestCase):
    """Checks ``_get_expanded_columns`` for plain column names that are all in ``mutual``."""

    db_cls = MySQL

    # No setUp override: the inherited DiffTestCase.setUp is used directly.

    def test__get_expanded_columns(self) -> None:
        columns = ["user_id", "movie_id", "rating"]
        expanded_columns = _get_expanded_columns(
            columns,
            False,  # case_sensitive
            set(columns),  # mutual
            db1=self.connection,
            schema1=self.src_schema,
            table1=self.table_src_name,
            db2=self.connection,
            schema2=self.dst_schema,
            table2=self.table_dst_name,
        )

        # The result is a set, so one equality check covers both its size and
        # its contents.
        assert expanded_columns == set(columns)


class TestGetThreads(unittest.TestCase):
    """Checks the ``(threaded, threads)`` pair returned by ``_get_threads``."""

    def test__get_threads(self):
        # No thread count given: threaded with a single thread.
        is_threaded, count = _get_threads(None, None, None)
        assert is_threaded and count == 1

        # Per-database thread counts alone leave the result unchanged.
        is_threaded, count = _get_threads(None, 2, 3)
        assert is_threaded and count == 1

        # "serial" turns threading off.
        is_threaded, count = _get_threads("serial", None, None)
        assert not is_threaded and count == 1

        # "serial" combined with per-database thread counts is rejected.
        with self.assertRaises(AssertionError):
            _get_threads("serial", 1, 2)

        # A numeric string is accepted as the thread count.
        is_threaded, count = _get_threads("4", None, None)
        assert is_threaded and count == 4

        # Any other string fails the integer conversion.
        with self.assertRaises(ValueError) as raised:
            _get_threads("auto", None, None)
        self.assertEqual(str(raised.exception), "invalid literal for int() with base 10: 'auto'")

        # Integer thread counts, with and without per-database counts.
        is_threaded, count = _get_threads(5, None, None)
        assert is_threaded and count == 5

        is_threaded, count = _get_threads(6, 7, 8)
        assert is_threaded and count == 6

        # A thread count below one is rejected.
        with self.assertRaises(ValueError) as raised:
            _get_threads(0, None, None)
        self.assertEqual(str(raised.exception), "Error: threads must be >= 1")
Loading

0 comments on commit b2a1542

Please sign in to comment.