diff --git a/backend_py/primary/primary/services/summary_delta_vectors.py b/backend_py/primary/primary/services/summary_delta_vectors.py index 470cb61e7..572190758 100644 --- a/backend_py/primary/primary/services/summary_delta_vectors.py +++ b/backend_py/primary/primary/services/summary_delta_vectors.py @@ -4,7 +4,7 @@ import pyarrow.compute as pc import numpy as np -from .utils.summary_vector_table_helpers import validate_summary_vector_table +from .utils.summary_vector_table_helpers import validate_summary_vector_table_pa @dataclass @@ -29,8 +29,8 @@ def create_delta_vector_table( `Note`: Pre-processing of DATE-columns, e.g. resampling, should be done before calling this function. """ - validate_summary_vector_table(first_vector_table, vector_name) - validate_summary_vector_table(second_vector_table, vector_name) + validate_summary_vector_table_pa(first_vector_table, vector_name) + validate_summary_vector_table_pa(second_vector_table, vector_name) joined_vector_table = first_vector_table.join( second_vector_table, keys=["DATE", "REAL"], join_type="inner", right_suffix="_second" @@ -56,7 +56,7 @@ def create_realization_delta_vector_list( """ Create a list of RealizationDeltaVector from the delta vector table. 
""" - validate_summary_vector_table(delta_vector_table, vector_name) + validate_summary_vector_table_pa(delta_vector_table, vector_name) real_arr_np = delta_vector_table.column("REAL").to_numpy() unique_reals, first_occurrence_idx, real_counts = np.unique(real_arr_np, return_index=True, return_counts=True) diff --git a/backend_py/primary/primary/services/sumo_access/summary_access.py b/backend_py/primary/primary/services/sumo_access/summary_access.py index f636a8c3d..de5052ac7 100644 --- a/backend_py/primary/primary/services/sumo_access/summary_access.py +++ b/backend_py/primary/primary/services/sumo_access/summary_access.py @@ -16,7 +16,7 @@ sort_table_on_real_then_date, is_date_column_monotonically_increasing, ) -from primary.services.utils.summary_vector_table_helpers import validate_summary_vector_table +from primary.services.utils.summary_vector_table_helpers import validate_summary_vector_table_pa from primary.services.service_exceptions import ( Service, NoDataError, @@ -164,7 +164,7 @@ async def get_vector_async( table, vector_metadata = await self.get_vector_table_async(vector_name, resampling_frequency, realizations) # Verify that columns are as we expect - validate_summary_vector_table(table, vector_name, Service.SUMO) + validate_summary_vector_table_pa(table, vector_name, Service.SUMO) real_arr_np = table.column("REAL").to_numpy() unique_reals, first_occurrence_idx, real_counts = np.unique(real_arr_np, return_index=True, return_counts=True) diff --git a/backend_py/primary/primary/services/utils/summary_vector_table_helpers.py b/backend_py/primary/primary/services/utils/summary_vector_table_helpers.py index b866713dd..1a1bb9554 100644 --- a/backend_py/primary/primary/services/utils/summary_vector_table_helpers.py +++ b/backend_py/primary/primary/services/utils/summary_vector_table_helpers.py @@ -3,15 +3,18 @@ from primary.services.service_exceptions import InvalidDataError, Service -def validate_summary_vector_table(vector_table: pa.Table, vector_name: 
def validate_summary_vector_table_pa(
    vector_table: pa.Table, vector_name: str, service: Service = Service.GENERAL
) -> None:
    """
    Check if the pyarrow vector table is valid.

    Expect the pyarrow single vector table to contain exactly the following columns: DATE, REAL, vector_name.

    Raises InvalidDataError, tagged with the given service, if the set of column names in the table differs
    from the expected set. The error message lists the unexpected columns and/or the missing columns.
    """
    expected_columns = {"DATE", "REAL", vector_name}
    actual_columns = set(vector_table.column_names)
    if actual_columns == expected_columns:
        return

    unexpected_columns = actual_columns - expected_columns
    missing_columns = expected_columns - actual_columns

    # The sets can differ because of extra columns, absent columns, or both. Report each non-empty
    # side so the message never degenerates to an empty "set()" when only required columns are missing.
    problems = []
    if unexpected_columns:
        problems.append(f"Unexpected columns in table {unexpected_columns}")
    if missing_columns:
        problems.append(f"Missing columns in table {missing_columns}")
    raise InvalidDataError(", ".join(problems), service)