diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index c2018ae..e75d384 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -3,5 +3,5 @@
 
 
 @pytest.fixture(scope="session")
-def SPARK():
+def spark_session() -> SparkSession:
     return SparkSession.builder.appName("IntegrationTests").getOrCreate()
diff --git a/tests/integration/test_distance_transformer.py b/tests/integration/test_distance_transformer.py
index 5de4572..323c696 100644
--- a/tests/integration/test_distance_transformer.py
+++ b/tests/integration/test_distance_transformer.py
@@ -81,12 +81,13 @@
 ]
 
 
-def test_should_maintain_all_data_it_reads(SPARK) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    given_dataframe = SPARK.read.parquet(given_ingest_folder)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_maintain_all_data_it_reads(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    given_dataframe = spark_session.read.parquet(given_ingest_folder)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)
 
-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
     actual_columns = set(actual_dataframe.columns)
     actual_schema = set(actual_dataframe.schema)
     expected_columns = set(given_dataframe.columns)
@@ -97,12 +98,13 @@ def test_should_maintain_all_data_it_reads(SPARK) -> None:
 
 
 @pytest.mark.skip
-def test_should_add_distance_column_with_calculated_distance(SPARK) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_add_distance_column_with_calculated_distance(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)
 
-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
-    expected_dataframe = SPARK.createDataFrame(
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
+    expected_dataframe = spark_session.createDataFrame(
         [
             SAMPLE_DATA[0] + [1.07],
             SAMPLE_DATA[1] + [0.92],
diff --git a/tests/integration/test_ingest.py b/tests/integration/test_ingest.py
index 366be44..db82a29 100644
--- a/tests/integration/test_ingest.py
+++ b/tests/integration/test_ingest.py
@@ -3,10 +3,12 @@
 import tempfile
 from typing import Tuple, List
 
+from pyspark.sql import SparkSession
+
 from data_transformations.citibike import ingest
 
 
-def test_should_sanitize_column_names(SPARK) -> None:
+def test_should_sanitize_column_names(spark_session: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders()
     input_csv_path = given_ingest_folder + 'input.csv'
     csv_content = [
@@ -15,10 +17,10 @@ def test_should_sanitize_column_names(SPARK) -> None:
         ['1', '5', '2'],
     ]
     __write_csv_file(input_csv_path, csv_content)
-    ingest.run(SPARK, input_csv_path, given_transform_folder)
+    ingest.run(spark_session, input_csv_path, given_transform_folder)
 
-    actual = SPARK.read.parquet(given_transform_folder)
-    expected = SPARK.createDataFrame(
+    actual = spark_session.read.parquet(given_transform_folder)
+    expected = spark_session.createDataFrame(
         [
             ['3', '4', '1'],
             ['1', '5', '2']
diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index 984c170..e53c2b0 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -5,39 +5,44 @@
 import pytest
 
 
-def test_java_home_is_set():
+def test_java_home_is_set() -> None:
     java_home = os.environ.get("JAVA_HOME")
-    assert java_home is not None, "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."
+    assert java_home is not None, \
+        "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."
 
 
-def test_java_version_minimum_requirement(expected_major_version=11):
+def test_java_version_minimum_requirement(expected_major_version: int = 11) -> None:
     version_line = __extract_version_line(__java_version_output())
     major_version = __parse_major_version(version_line)
-    assert major_version >= expected_major_version, (f"Major version {major_version} is not recent enough, "
-                                                     f"we need at least version {expected_major_version}.")
+    assert major_version >= expected_major_version, (
+        f"Major version {major_version} is not recent enough, "
+        f"we need at least version {expected_major_version}.")
 
 
-def __java_version_output():
-    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode("utf-8")
+def __java_version_output() -> str:
+    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode(
+        "utf-8")
     print(f"\n`java -version` returned\n{java_version}")
     return java_version
 
 
-def __extract_version_line(java_version_output):
-    version_line = next((line for line in java_version_output.splitlines() if "version" in line), None)
+def __extract_version_line(java_version_output: str) -> str:
+    version_line = next((line for line in java_version_output.splitlines() if "version" in line),
+                        None)
     if not version_line:
         pytest.fail("Couldn't find version information in `java -version` output.")
     return version_line
 
 
-def __parse_major_version(version_line):
-    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\d+"')
+# pylint: disable=R1710
+def __parse_major_version(version_line: str) -> int:
+    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\w+"')
     match = version_regex.search(version_line)
-    if not match:
-        return None
-    major_version = int(match.group("major"))
-    if major_version == 1:
-        major_version = int(match.group("minor"))
-    if major_version is None:
-        pytest.fail(f"Couldn't parse Java version from {version_line}.")
-    return major_version
+    if match is not None:
+        major_version = int(match.group("major"))
+        if major_version == 1:
+            # we need to jump this hoop due to Java version naming conventions - it's fun:
+            # https://softwareengineering.stackexchange.com/questions/175075/why-is-java-version-1-x-referred-to-as-java-x
+            major_version = int(match.group("minor"))
+        return major_version
+    pytest.fail(f"Couldn't parse Java version from {version_line}.")
diff --git a/tests/integration/test_word_count.py b/tests/integration/test_word_count.py
index 8ff15a0..2e212f6 100644
--- a/tests/integration/test_word_count.py
+++ b/tests/integration/test_word_count.py
@@ -3,6 +3,7 @@
 from typing import Tuple, List
 
 import pytest
+from pyspark.sql import SparkSession
 
 from data_transformations.wordcount import word_count_transformer
 
@@ -19,7 +20,7 @@ def _get_file_paths(input_file_lines: List[str]) -> Tuple[str, str]:
 
 
 @pytest.mark.skip
-def test_should_tokenize_words_and_count_them(SPARK) -> None:
+def test_should_tokenize_words_and_count_them(spark_session: SparkSession) -> None:
     lines = [
         "In my younger and more vulnerable years my father gave me some advice that I've been "
         "turning over in my mind ever since. \"Whenever you feel like criticising any one,\""
@@ -46,9 +47,9 @@ def test_should_tokenize_words_and_count_them(SPARK) -> None:
     ]
     input_file_path, output_path = _get_file_paths(lines)
 
-    word_count_transformer.run(SPARK, input_file_path, output_path)
+    word_count_transformer.run(spark_session, input_file_path, output_path)
 
-    actual = SPARK.read.csv(output_path, header=True, inferSchema=True)
+    actual = spark_session.read.csv(output_path, header=True, inferSchema=True)
     expected_data = [
         ["a", 4],
         ["across", 1],
@@ -258,6 +259,6 @@ def test_should_tokenize_words_and_count_them(SPARK) -> None:
         ["you've", 1],
         ["younger", 1],
     ]
-    expected = SPARK.createDataFrame(expected_data, ["word", "count"])
+    expected = spark_session.createDataFrame(expected_data, ["word", "count"])
 
     assert actual.collect() == expected.collect()