From 8be1e24b8a66bad4af564d0738818a7ba93a002f Mon Sep 17 00:00:00 2001
From: Nick Engelhardt
Date: Tue, 5 Nov 2024 08:38:14 +0100
Subject: [PATCH 1/5] Add missing type annotations

---
 tests/integration/conftest.py                        |  2 +-
 tests/integration/test_distance_transformer.py       |  4 ++--
 tests/integration/test_ingest.py                     |  4 +++-
 tests/integration/test_validate_spark_environment.py | 11 ++++++-----
 tests/integration/test_word_count.py                 |  3 ++-
 5 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index c2018ae..3a5342f 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -3,5 +3,5 @@


 @pytest.fixture(scope="session")
-def SPARK():
+def SPARK() -> SparkSession:
     return SparkSession.builder.appName("IntegrationTests").getOrCreate()

diff --git a/tests/integration/test_distance_transformer.py b/tests/integration/test_distance_transformer.py
index 5de4572..a507bb2 100644
--- a/tests/integration/test_distance_transformer.py
+++ b/tests/integration/test_distance_transformer.py
@@ -81,7 +81,7 @@
 ]


-def test_should_maintain_all_data_it_reads(SPARK) -> None:
+def test_should_maintain_all_data_it_reads(SPARK: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
     given_dataframe = SPARK.read.parquet(given_ingest_folder)
     distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
@@ -97,7 +97,7 @@ def test_should_maintain_all_data_it_reads(SPARK) -> None:


 @pytest.mark.skip
-def test_should_add_distance_column_with_calculated_distance(SPARK) -> None:
+def test_should_add_distance_column_with_calculated_distance(SPARK: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
     distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)

diff --git a/tests/integration/test_ingest.py b/tests/integration/test_ingest.py
index 366be44..8e3cfdb 100644
--- a/tests/integration/test_ingest.py
+++ b/tests/integration/test_ingest.py
@@ -3,10 +3,12 @@
 import tempfile
 from typing import Tuple, List

+from pyspark.sql import SparkSession
+
 from data_transformations.citibike import ingest


-def test_should_sanitize_column_names(SPARK) -> None:
+def test_should_sanitize_column_names(SPARK: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders()
     input_csv_path = given_ingest_folder + 'input.csv'
     csv_content = [

diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index 984c170..f85c31c 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -1,36 +1,37 @@
 import os
 import re
 import subprocess
+from typing import Optional

 import pytest


-def test_java_home_is_set():
+def test_java_home_is_set() -> None:
     java_home = os.environ.get("JAVA_HOME")
     assert java_home is not None, "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."


-def test_java_version_minimum_requirement(expected_major_version=11):
+def test_java_version_minimum_requirement(expected_major_version:int =11) -> None:
     version_line = __extract_version_line(__java_version_output())
     major_version = __parse_major_version(version_line)
     assert major_version >= expected_major_version, (f"Major version {major_version} is not recent enough, "
                                                      f"we need at least version {expected_major_version}.")


-def __java_version_output():
+def __java_version_output() -> str:
     java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode("utf-8")
     print(f"\n`java -version` returned\n{java_version}")
     return java_version


-def __extract_version_line(java_version_output):
+def __extract_version_line(java_version_output:str) -> str:
     version_line = next((line for line in java_version_output.splitlines() if "version" in line), None)
     if not version_line:
         pytest.fail("Couldn't find version information in `java -version` output.")
     return version_line


-def __parse_major_version(version_line):
+def __parse_major_version(version_line:str) -> Optional[int]:
     version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\d+"')
     match = version_regex.search(version_line)
     if not match:

diff --git a/tests/integration/test_word_count.py b/tests/integration/test_word_count.py
index 8ff15a0..8dbc620 100644
--- a/tests/integration/test_word_count.py
+++ b/tests/integration/test_word_count.py
@@ -3,6 +3,7 @@
 from typing import Tuple, List

 import pytest
+from pyspark.sql import SparkSession

 from data_transformations.wordcount import word_count_transformer

@@ -19,7 +20,7 @@ def _get_file_paths(input_file_lines: List[str]) -> Tuple[str, str]:


 @pytest.mark.skip
-def test_should_tokenize_words_and_count_them(SPARK) -> None:
+def test_should_tokenize_words_and_count_them(SPARK: SparkSession) -> None:
     lines = [
         "In my younger and more vulnerable years my father gave me some advice that I've been "
         "turning over in my mind ever since. \"Whenever you feel like criticising any one,\""

From 1da31db179c991bcbf4832d5dba318888e7e26ec Mon Sep 17 00:00:00 2001
From: Nick Engelhardt
Date: Tue, 5 Nov 2024 08:48:21 +0100
Subject: [PATCH 2/5] Remove inconsistency in return types

---
 .../test_validate_spark_environment.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index f85c31c..5787609 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -31,14 +31,12 @@ def __extract_version_line(java_version_output:str) -> str:
     return version_line


-def __parse_major_version(version_line:str) -> Optional[int]:
+def __parse_major_version(version_line:str) -> int:
     version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\d+"')
     match = version_regex.search(version_line)
-    if not match:
-        return None
-    major_version = int(match.group("major"))
-    if major_version == 1:
-        major_version = int(match.group("minor"))
-    if major_version is None:
-        pytest.fail(f"Couldn't parse Java version from {version_line}.")
-    return major_version
+    if match:
+        major_version = int(match.group("major"))
+        if major_version == 1:  # Java 8 reports version as 1.actual_major.actual_minor
+            major_version = int(match.group("minor"))
+        return major_version
+    pytest.fail(f"Couldn't parse Java version from {version_line}.")

From 7f23603a95af6a8c683ec8a2790f856e3f341c25 Mon Sep 17 00:00:00 2001
From: Nick Engelhardt
Date: Tue, 5 Nov 2024 09:06:35 +0100
Subject: [PATCH 3/5] Fix incorrect regex

The previous regex wasn't working for java version 8, because that's
reported as e.g. 'openjdk version "1.8.0_412"' and the underscore
doesn't match regex pattern '\d+'.
---
 .../test_validate_spark_environment.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index 5787609..9e8720f 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -1,7 +1,6 @@
 import os
 import re
 import subprocess
-from typing import Optional

 import pytest

@@ -11,7 +10,7 @@ def test_java_home_is_set() -> None:
     java_home = os.environ.get("JAVA_HOME")
     assert java_home is not None, "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."


-def test_java_version_minimum_requirement(expected_major_version:int =11) -> None:
+def test_java_version_minimum_requirement(expected_major_version: int = 11) -> None:
     version_line = __extract_version_line(__java_version_output())
     major_version = __parse_major_version(version_line)
     assert major_version >= expected_major_version, (f"Major version {major_version} is not recent enough, "
                                                      f"we need at least version {expected_major_version}.")
@@ -24,19 +23,21 @@ def __java_version_output() -> str:
     return java_version


-def __extract_version_line(java_version_output:str) -> str:
+def __extract_version_line(java_version_output: str) -> str:
     version_line = next((line for line in java_version_output.splitlines() if "version" in line), None)
     if not version_line:
         pytest.fail("Couldn't find version information in `java -version` output.")
     return version_line


-def __parse_major_version(version_line:str) -> int:
-    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\d+"')
+def __parse_major_version(version_line: str) -> int:
+    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\w+"')
     match = version_regex.search(version_line)
-    if match:
+    if match is not None:
         major_version = int(match.group("major"))
-        if major_version == 1:  # Java 8 reports version as 1.actual_major.actual_minor
+        if major_version == 1:
+            # we need to jump this hoop due to Java version naming conventions - it's fun:
+            # https://softwareengineering.stackexchange.com/questions/175075/why-is-java-version-1-x-referred-to-as-java-x
            major_version = int(match.group("minor"))
         return major_version
     pytest.fail(f"Couldn't parse Java version from {version_line}.")
From 02c95b68e77d861a3ca7641dd5071276753db50b Mon Sep 17 00:00:00 2001
From: Nick Engelhardt
Date: Tue, 5 Nov 2024 09:12:56 +0100
Subject: [PATCH 4/5] Fix most linter suggestions

---
 tests/integration/conftest.py                 |  2 +-
 .../integration/test_distance_transformer.py | 22 ++++++++++---------
 tests/integration/test_ingest.py              |  8 +++----
 .../test_validate_spark_environment.py        | 14 +++++++-----
 tests/integration/test_word_count.py          |  8 +++----
 5 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 3a5342f..e75d384 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -3,5 +3,5 @@


 @pytest.fixture(scope="session")
-def SPARK() -> SparkSession:
+def spark_session() -> SparkSession:
     return SparkSession.builder.appName("IntegrationTests").getOrCreate()

diff --git a/tests/integration/test_distance_transformer.py b/tests/integration/test_distance_transformer.py
index a507bb2..323c696 100644
--- a/tests/integration/test_distance_transformer.py
+++ b/tests/integration/test_distance_transformer.py
@@ -81,12 +81,13 @@
 ]


-def test_should_maintain_all_data_it_reads(SPARK: SparkSession) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    given_dataframe = SPARK.read.parquet(given_ingest_folder)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_maintain_all_data_it_reads(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    given_dataframe = spark_session.read.parquet(given_ingest_folder)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)

-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
     actual_columns = set(actual_dataframe.columns)
     actual_schema = set(actual_dataframe.schema)
     expected_columns = set(given_dataframe.columns)
@@ -97,12 +98,13 @@ def test_should_maintain_all_data_it_reads(SPARK: SparkSession) -> None:


 @pytest.mark.skip
-def test_should_add_distance_column_with_calculated_distance(SPARK: SparkSession) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_add_distance_column_with_calculated_distance(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)

-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
-    expected_dataframe = SPARK.createDataFrame(
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
+    expected_dataframe = spark_session.createDataFrame(
         [
             SAMPLE_DATA[0] + [1.07],
             SAMPLE_DATA[1] + [0.92],

diff --git a/tests/integration/test_ingest.py b/tests/integration/test_ingest.py
index 8e3cfdb..db82a29 100644
--- a/tests/integration/test_ingest.py
+++ b/tests/integration/test_ingest.py
@@ -8,7 +8,7 @@
 from data_transformations.citibike import ingest


-def test_should_sanitize_column_names(SPARK: SparkSession) -> None:
+def test_should_sanitize_column_names(spark_session: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders()
     input_csv_path = given_ingest_folder + 'input.csv'
     csv_content = [
@@ -17,10 +17,10 @@ def test_should_sanitize_column_names(SPARK: SparkSession) -> None:
         ['1', '5', '2'],
     ]
     __write_csv_file(input_csv_path, csv_content)
-    ingest.run(SPARK, input_csv_path, given_transform_folder)
+    ingest.run(spark_session, input_csv_path, given_transform_folder)

-    actual = SPARK.read.parquet(given_transform_folder)
-    expected = SPARK.createDataFrame(
+    actual = spark_session.read.parquet(given_transform_folder)
+    expected = spark_session.createDataFrame(
         [
             ['3', '4', '1'],
             ['1', '5', '2']

diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index 9e8720f..ae930c6 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -7,24 +7,28 @@

 def test_java_home_is_set() -> None:
     java_home = os.environ.get("JAVA_HOME")
-    assert java_home is not None, "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."
+    assert java_home is not None, \
+        "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."


 def test_java_version_minimum_requirement(expected_major_version: int = 11) -> None:
     version_line = __extract_version_line(__java_version_output())
     major_version = __parse_major_version(version_line)
-    assert major_version >= expected_major_version, (f"Major version {major_version} is not recent enough, "
-                                                     f"we need at least version {expected_major_version}.")
+    assert major_version >= expected_major_version, (
+        f"Major version {major_version} is not recent enough, "
+        f"we need at least version {expected_major_version}.")


 def __java_version_output() -> str:
-    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode("utf-8")
+    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode(
+        "utf-8")
     print(f"\n`java -version` returned\n{java_version}")
     return java_version


 def __extract_version_line(java_version_output: str) -> str:
-    version_line = next((line for line in java_version_output.splitlines() if "version" in line), None)
+    version_line = next((line for line in java_version_output.splitlines() if "version" in line),
+                        None)
     if not version_line:
         pytest.fail("Couldn't find version information in `java -version` output.")
     return version_line

diff --git a/tests/integration/test_word_count.py b/tests/integration/test_word_count.py
index 8dbc620..2e212f6 100644
--- a/tests/integration/test_word_count.py
+++ b/tests/integration/test_word_count.py
@@ -20,7 +20,7 @@ def _get_file_paths(input_file_lines: List[str]) -> Tuple[str, str]:


 @pytest.mark.skip
-def test_should_tokenize_words_and_count_them(SPARK: SparkSession) -> None:
+def test_should_tokenize_words_and_count_them(spark_session: SparkSession) -> None:
     lines = [
         "In my younger and more vulnerable years my father gave me some advice that I've been "
         "turning over in my mind ever since. \"Whenever you feel like criticising any one,\""
@@ -47,9 +47,9 @@ def test_should_tokenize_words_and_count_them(SPARK: SparkSession) -> None:
     ]

     input_file_path, output_path = _get_file_paths(lines)
-    word_count_transformer.run(SPARK, input_file_path, output_path)
+    word_count_transformer.run(spark_session, input_file_path, output_path)

-    actual = SPARK.read.csv(output_path, header=True, inferSchema=True)
+    actual = spark_session.read.csv(output_path, header=True, inferSchema=True)
     expected_data = [
         ["a", 4],
         ["across", 1],
@@ -259,6 +259,6 @@ def test_should_tokenize_words_and_count_them(SPARK: SparkSession) -> None:
         ["you've", 1],
         ["younger", 1],
     ]
-    expected = SPARK.createDataFrame(expected_data, ["word", "count"])
+    expected = spark_session.createDataFrame(expected_data, ["word", "count"])

     assert actual.collect() == expected.collect()

From 604a870b06b778e016315f6dfbaf5ffd0ad30041 Mon Sep 17 00:00:00 2001
From: Nick Engelhardt
Date: Tue, 5 Nov 2024 09:13:33 +0100
Subject: [PATCH 5/5] Disable inconsistent return type check

In this case, if the major version is not returned from the block, the
function does not return at all but fails via pytest.fail, so the check
does not apply here.
---
 tests/integration/test_validate_spark_environment.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/integration/test_validate_spark_environment.py b/tests/integration/test_validate_spark_environment.py
index ae930c6..e53c2b0 100644
--- a/tests/integration/test_validate_spark_environment.py
+++ b/tests/integration/test_validate_spark_environment.py
@@ -34,6 +34,7 @@ def __extract_version_line(java_version_output: str) -> str:
     return version_line


+# pylint: disable=R1710
 def __parse_major_version(version_line: str) -> int:
     version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\w+"')
     match = version_regex.search(version_line)
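
A closing note on the `# pylint: disable=R1710` pragma: R1710
(inconsistent-return-statements) fires because pylint assumes execution can
fall past the `if match is not None:` block and implicitly return None. It
cannot, since pytest.fail raises unconditionally. Here is a standalone sketch
of the same shape (typing.NoReturn and a plain RuntimeError stand in for
pytest.fail; both are assumptions of this example, not code from the patches):

    import re
    from typing import NoReturn

    def fail(reason: str) -> NoReturn:
        # Stands in for pytest.fail: raises unconditionally, never returns.
        raise RuntimeError(reason)

    def parse_major_version(version_line: str) -> int:
        match = re.search(r'"(\d+)\.', version_line)
        if match is not None:
            return int(match.group(1))
        fail(f"Couldn't parse Java version from {version_line}.")  # always raises

    print(parse_major_version('openjdk version "11.0.2" 2019-01-15'))  # 11

With the helper annotated NoReturn, a type checker accepts parse_major_version
as always returning int; pylint can also be told about such never-returning
helpers (its never-returning-functions option), but a local pragma as used in
the patch is the smaller change.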