diff --git a/CHANGELOG.md b/CHANGELOG.md
index c0f652716c..7b1ab95851 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -245,6 +245,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 #### Added
 - `inflows` and `outflows` exposed via API endpoint + added to flowclient [#2029](https://github.com/Flowminder/FlowKit/issues/2029), [#4866](https://github.com/Flowminder/FlowKit/issues/4866)
+- Added a new `@pre_flight` decorator, which `Query` subclasses may use to indicate that a method should be run to confirm the query is runnable
+- Added a new `preflight()` method to `Query`, which calls all applicable pre-flight check methods for the query

 ### Changed
 - __Action Needed__ Airflow updated to version 2.3.3; **backup flowetl_db before applying update** [#4940](https://github.com/Flowminder/FlowKit/pull/4940)
@@ -255,6 +257,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - FlowDB now always creates a role named `flowmachine.`
 - Flowmachine will set the state of a query being stored to cancelled if interrupted while the store is running.
 - Flowmachine now supports sqlalchemy >=1.4 [#5140](https://github.com/Flowminder/FlowKit/issues/5140)
+- `get_cached_query_objects_ordered_by_score` is now a generator. [#3116](https://github.com/Flowminder/FlowKit/issues/3116)
+- Queries should no longer require communication with the database during `__init__`; any checks that require database access must now be implemented as a method of the class using the `@pre_flight` decorator
+- When specifying tables in Flowmachine, the `events.` prefix is no longer required.

 ### Fixed
 - Flowmachine now makes the built in `flowmachine` role owner of cache tables as a post-action when a query is `store`d. [#4714](https://github.com/Flowminder/FlowKit/issues/4714)
@@ -265,6 +270,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 ### Removed
 - `use_file_flux_sensor` removed entirely. [#2812](https://github.com/Flowminder/FlowKit/issues/2812)
 - `Model`, `ModelResult` and `Louvain` have been removed. [#5168](https://github.com/Flowminder/FlowKit/issues/5168)
+- `Table` no longer automatically infers columns from the database; the columns must now be specified explicitly.
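The sketch below illustrates how the new hooks fit together, modelled on the `Table.check_exists` and `EventTableSubset.check_dates` hooks added later in this diff. It is illustrative only: `DataTable`, its constructor argument, and the `_make_query`/`column_names` stubs are hypothetical; `pre_flight`, `preflight()`, `is_stored`, and the module paths are as introduced in this patch.

```python
from typing import List

from flowmachine.core.preflight import pre_flight
from flowmachine.core.query import Query


class DataTable(Query):  # hypothetical subclass for illustration
    def __init__(self, fqn: str) -> None:
        self.fqn = fqn
        super().__init__()  # no database round-trip during __init__ any more

    @pre_flight
    def check_exists(self) -> None:
        # Run when preflight() is called, for this query and for each of
        # its dependencies; raising an exception marks the check as failed.
        if not self.is_stored:
            raise ValueError(f"{self.fqn} is not a known table.")

    def _make_query(self) -> str:
        return f"SELECT * FROM {self.fqn}"

    @property
    def column_names(self) -> List[str]:
        return ["id"]
```

Calling `DataTable("events.calls").preflight()` walks the query's dependency graph, runs every `@pre_flight` hook it finds, and raises `PreFlightFailedException` carrying a mapping from each failing dependency to the exceptions it raised; `store_async` and `write_query_to_cache` now invoke it before executing any SQL, as the hunks further down in this diff show.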
## [1.16.0] diff --git a/docs/source/analyst/advanced_usage/worked_examples/mobile-data-usage.ipynb b/docs/source/analyst/advanced_usage/worked_examples/mobile-data-usage.ipynb index 6f07c89a48..5a22a2baaf 100644 --- a/docs/source/analyst/advanced_usage/worked_examples/mobile-data-usage.ipynb +++ b/docs/source/analyst/advanced_usage/worked_examples/mobile-data-usage.ipynb @@ -92,7 +92,7 @@ "data_events_query = flowmachine.features.TotalLocationEvents(\n", " start=\"2016-01-01\",\n", " stop=\"2016-01-08\",\n", - " table=\"events.mds\",\n", + " table=\"mds\",\n", " spatial_unit=make_spatial_unit(\"versioned-cell\"),\n", " interval=\"hour\",\n", ")" diff --git a/flowdb/bin/build/0020_schema_cache.sql b/flowdb/bin/build/0020_schema_cache.sql index 0b3bf37bc4..980ba91461 100644 --- a/flowdb/bin/build/0020_schema_cache.sql +++ b/flowdb/bin/build/0020_schema_cache.sql @@ -48,4 +48,7 @@ CREATE TABLE IF NOT EXISTS cache.dependencies CREATE TABLE cache.cache_config (key text, value text); INSERT INTO cache.cache_config (key, value) VALUES ('half_life', NULL); INSERT INTO cache.cache_config (key, value) VALUES ('cache_size', NULL); -INSERT INTO cache.cache_config (key, value) VALUES ('cache_protected_period', NULL); \ No newline at end of file +INSERT INTO cache.cache_config (key, value) VALUES ('cache_protected_period', NULL); + +CREATE TABLE cache.zero_cache (object_class text); +INSERT INTO cache.zero_cache (object_class) VALUES ('Table'), ('GeoTable'), ('CallsTable'), ('SmsTable'), ('MdsTable'), ('TopupsTable'), ('ForwardsTable'), ('TacsTable'), ('CellsTable'), ('SitesTable'); \ No newline at end of file diff --git a/flowdb/bin/build/0030_utilities.sql b/flowdb/bin/build/0030_utilities.sql index a89fe4d023..98217d518e 100644 --- a/flowdb/bin/build/0030_utilities.sql +++ b/flowdb/bin/build/0030_utilities.sql @@ -243,9 +243,10 @@ $$ DECLARE score float; BEGIN UPDATE cache.cached SET last_accessed = NOW(), access_count = access_count + 1, - cache_score_multiplier = CASE WHEN class='Table' THEN 0 ELSE + cache_score_multiplier = CASE WHEN class=ANY(no_score.classes) THEN 0 ELSE cache_score_multiplier+POWER(1 + ln(2) / cache_half_life(), nextval('cache.cache_touches') - 2) END + FROM (SELECT array_agg(object_class) as classes FROM cache.zero_cache) AS no_score WHERE query_id=cached_query_id RETURNING cache_score(cache_score_multiplier, compute_time, greatest(table_size(tablename, schema), 0.00001)) INTO score; IF NOT FOUND THEN RAISE EXCEPTION 'Cache record % not found', cached_query_id; diff --git a/flowmachine/flowmachine/core/cache.py b/flowmachine/flowmachine/core/cache.py index 12962c9010..5fa80b60f5 100644 --- a/flowmachine/flowmachine/core/cache.py +++ b/flowmachine/flowmachine/core/cache.py @@ -12,7 +12,7 @@ import sqlalchemy.engine from contextvars import copy_context from concurrent.futures import Executor, TimeoutError -from functools import partial +from functools import partial, lru_cache from sqlalchemy.exc import ResourceClosedError from typing import TYPE_CHECKING, Tuple, List, Callable, Optional @@ -29,6 +29,7 @@ from flowmachine.core.query_state import QueryStateMachine, QueryEvent from flowmachine import __version__ + if TYPE_CHECKING: from .query import Query from .connection import Connection @@ -191,8 +192,10 @@ def write_query_to_cache( if this_thread_is_owner: logger.debug(f"In charge of executing '{query.query_id}'.") try: + query.preflight() query_ddl_ops = ddl_ops_func(name, schema) except Exception as exc: + q_state_machine.raise_error() logger.error(f"Error generating SQL. 
Error was {exc}") raise exc logger.debug("Made SQL.") @@ -204,6 +207,7 @@ def write_query_to_cache( ) logger.debug("Executed queries.") except Exception as exc: + q_state_machine.raise_error() logger.error(f"Error executing SQL. Error was {exc}") raise exc if analyze: @@ -219,6 +223,7 @@ def write_query_to_cache( executed_sql=";\n".join(query_ddl_ops), ) except Exception as exc: + q_state_machine.raise_error() logger.error(f"Error writing cache metadata. Error was {exc}") raise exc q_state_machine.finish() @@ -229,7 +234,6 @@ def write_query_to_cache( finally: if this_thread_is_owner and not q_state_machine.is_finished_executing: q_state_machine.cancel() - q_state_machine.wait_until_complete(sleep_duration=sleep_duration) if q_state_machine.is_completed: return query @@ -301,6 +305,7 @@ def write_cache_metadata( psycopg2.Binary(self_storage), ), ) + logger.debug("Touching cache.", query_id=query.query_id, query=str(query)) connection.exec_driver_sql( "SELECT touch_cache(%(ident)s);", dict(ident=query.query_id) ) @@ -334,6 +339,7 @@ def touch_cache(connection: "Connection", query_id: str) -> float: The new cache score """ try: + logger.debug("Touching cache.", query_id=query_id) with connection.engine.begin() as trans: return float( trans.exec_driver_sql(f"SELECT touch_cache('{query_id}')").fetchall()[ @@ -481,6 +487,19 @@ def get_query_object_by_id(connection: "Connection", query_id: str) -> "Query": raise ValueError(f"Query id '{query_id}' is not in cache on this connection.") +@lru_cache(maxsize=1) +def _get_protected_classes(): + from flowmachine.core.events_table import events_table_map + from flowmachine.core.infrastructure_table import infrastructure_table_map + + return [ + "Table", + "GeoTable", + *[cls.__name__ for cls in events_table_map.values()], + *[cls.__name__ for cls in infrastructure_table_map.values()], + ] + + def get_cached_query_objects_ordered_by_score( connection: "Connection", protected_period: Optional[int] = None, @@ -502,6 +521,7 @@ def get_cached_query_objects_ordered_by_score( Returns a list of cached Query objects with their on disk sizes """ + protected_period_clause = ( (f" AND NOW()-created > INTERVAL '{protected_period} seconds'") if protected_period is not None @@ -509,7 +529,7 @@ def get_cached_query_objects_ordered_by_score( ) qry = f"""SELECT query_id, table_size(tablename, schema) as table_size FROM cache.cached - WHERE cached.class!='Table' AND cached.class!='GeoTable' + WHERE NOT (cached.class=ANY(ARRAY{_get_protected_classes()})) {protected_period_clause} ORDER BY cache_score(cache_score_multiplier, compute_time, table_size(tablename, schema)) ASC """ @@ -689,9 +709,9 @@ def get_size_of_cache(connection: "Connection") -> int: Number of bytes in total used by cache tables """ - sql = """SELECT sum(table_size(tablename, schema)) as total_bytes + sql = f"""SELECT sum(table_size(tablename, schema)) as total_bytes FROM cache.cached - WHERE cached.class!='Table' AND cached.class!='GeoTable'""" + WHERE NOT (cached.class=ANY(ARRAY{_get_protected_classes()}))""" cache_bytes = connection.fetch(sql)[0][0] return 0 if cache_bytes is None else int(cache_bytes) diff --git a/flowmachine/flowmachine/core/errors/flowmachine_errors.py b/flowmachine/flowmachine/core/errors/flowmachine_errors.py index 60ce8fccb9..41c3605f1d 100644 --- a/flowmachine/flowmachine/core/errors/flowmachine_errors.py +++ b/flowmachine/flowmachine/core/errors/flowmachine_errors.py @@ -6,6 +6,27 @@ """ Custom errors raised by flowmachine. 
""" +from typing import List, Dict + + +class PreFlightFailedException(Exception): + """ + Exception indicating that preflight checks for a query failed. + + Parameters + ---------- + query_id : str + Identifier of the query + errors : dict + Mapping from query reps to lists of exceptions raised in preflight + """ + + def __init__(self, query_id: str, errors: Dict[str, List[Exception]]): + self.errors = errors + self.query_id = query_id + Exception.__init__( + self, f"Pre-flight failed for '{self.query_id}'. Errors: {errors}" + ) class StoreFailedException(Exception): diff --git a/flowmachine/flowmachine/core/events_table.py b/flowmachine/flowmachine/core/events_table.py new file mode 100644 index 0000000000..f1f5543dc8 --- /dev/null +++ b/flowmachine/flowmachine/core/events_table.py @@ -0,0 +1,125 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Optional + +from flowmachine.core.flowdb_table import FlowDBTable + + +class EventsTable(FlowDBTable): + def __init__(self, *, name, columns: Optional[list[str]] = None) -> None: + super().__init__(schema="events", name=name, columns=columns) + + +class CallsTable(EventsTable): + all_columns = [ + "id", + "outgoing", + "datetime", + "duration", + "network", + "msisdn", + "msisdn_counterpart", + "location_id", + "imsi", + "imei", + "tac", + "operator_code", + "country_code", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="calls", columns=columns) + + +class ForwardsTable(EventsTable): + all_columns = [ + "id", + "outgoing", + "datetime", + "network", + "msisdn", + "msisdn_counterpart", + "location_id", + "imsi", + "imei", + "tac", + "operator_code", + "country_code", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="forwards", columns=columns) + + +class SmsTable(EventsTable): + all_columns = [ + "id", + "outgoing", + "datetime", + "network", + "msisdn", + "msisdn_counterpart", + "location_id", + "imsi", + "imei", + "tac", + "operator_code", + "country_code", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="sms", columns=columns) + + +class MdsTable(EventsTable): + all_columns = [ + "id", + "datetime", + "duration", + "volume_total", + "volume_upload", + "volume_download", + "msisdn", + "location_id", + "imsi", + "imei", + "tac", + "operator_code", + "country_code", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="mds", columns=columns) + + +class TopupsTable(EventsTable): + all_columns = [ + "id", + "datetime", + "type", + "recharge_amount", + "airtime_fee", + "tax_and_fee", + "pre_event_balance", + "post_event_balance", + "msisdn", + "location_id", + "imsi", + "imei", + "tac", + "operator_code", + "country_code", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="topups", columns=columns) + + +events_table_map = dict( + calls=CallsTable, + sms=SmsTable, + mds=MdsTable, + topups=TopupsTable, + forwards=ForwardsTable, +) diff --git a/flowmachine/flowmachine/core/flowdb_table.py b/flowmachine/flowmachine/core/flowdb_table.py new file mode 100644 index 0000000000..d8f3000335 --- /dev/null +++ b/flowmachine/flowmachine/core/flowdb_table.py @@ -0,0 +1,34 @@ +# This Source Code Form is subject to the terms of the Mozilla Public 
+# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from abc import ABCMeta +from typing import Optional + +from flowmachine.core.table import Table + + +class FlowDBTable(Table, metaclass=ABCMeta): + """ + Abstract base class for fixed tables that exist in FlowDB. + + Parameters + ---------- + name : str + schema : str + columns : list of str + """ + + def __init__(self, *, name: str, schema: str, columns: Optional[list[str]]) -> None: + if columns is None: + columns = self.all_columns + if set(columns).issubset(self.all_columns): + super().__init__(schema=schema, name=name, columns=columns) + else: + raise ValueError( + f"Columns {columns} must be a subset of {self.all_columns}" + ) + + @property + def all_columns(self): + raise NotImplementedError diff --git a/flowmachine/flowmachine/core/geotable.py b/flowmachine/flowmachine/core/geotable.py index 7e4f227f01..9f4dbeeb39 100644 --- a/flowmachine/flowmachine/core/geotable.py +++ b/flowmachine/flowmachine/core/geotable.py @@ -6,6 +6,7 @@ """ Simple utility class that represents tables with geometry. """ +from typing import Optional, List from . import Table from .mixins import GeoDataMixin @@ -47,19 +48,25 @@ class GeoTable(GeoDataMixin, Table): """ def __init__( - self, name=None, schema=None, columns=None, geom_column="geom", gid_column=None + self, + name: str, + *, + schema: Optional[str] = None, + columns: List[str], + geom_column: str = "geom", + gid_column: Optional[str] = None, ): self.geom_column = geom_column self.gid_column = gid_column - super().__init__(name=name, schema=schema, columns=columns) - if geom_column not in self.column_names: + if self.geom_column not in columns: raise ValueError( - "geom_column: {} is not a column in this table.".format(geom_column) + f"geom_column: {self.geom_column} is not a column in this table." ) - if gid_column is not None and gid_column not in self.column_names: + if self.gid_column is not None and self.gid_column not in columns: raise ValueError( - "gid_column: {} is not a column in this table.".format(gid_column) + f"gid_column: {self.gid_column} is not a column in this table." ) + super().__init__(name=name, schema=schema, columns=columns) def _geo_augmented_query(self): if self.gid_column is None: diff --git a/flowmachine/flowmachine/core/infrastructure_table.py b/flowmachine/flowmachine/core/infrastructure_table.py new file mode 100644 index 0000000000..0f723710fe --- /dev/null +++ b/flowmachine/flowmachine/core/infrastructure_table.py @@ -0,0 +1,125 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from typing import Optional + +from flowmachine.core.flowdb_table import FlowDBTable + + +class InfrastructureTable(FlowDBTable): + def __init__(self, *, name: str, columns: Optional[list[str]] = None) -> None: + super().__init__(schema="infrastructure", name=name, columns=columns) + + +class SitesTable(InfrastructureTable): + all_columns = [ + "site_id", + "id", + "version", + "name", + "type", + "status", + "structure_type", + "is_cow", + "date_of_first_service", + "date_of_last_service", + "geom_point", + "geom_polygon", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="sites", columns=columns) + + +class CellsTable(InfrastructureTable): + all_columns = [ + "cell_id", + "id", + "version", + "site_id", + "name", + "type", + "msc", + "bsc_rnc", + "antenna_type", + "status", + "lac", + "height", + "azimuth", + "transmitter", + "max_range", + "min_range", + "electrical_tilt", + "mechanical_downtilt", + "date_of_first_service", + "date_of_last_service", + "geom_point", + "geom_polygon", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="cells", columns=columns) + + +class TacsTable(InfrastructureTable): + all_columns = [ + "id", + "brand", + "model", + "width", + "height", + "depth", + "weight", + "display_type", + "display_colors", + "display_width", + "display_height", + "mms_receiver", + "mms_built_in_camera", + "wap_push_ota_support", + "hardware_gprs", + "hardware_edge", + "hardware_umts", + "hardware_wifi", + "hardware_bluetooth", + "hardware_gps", + "software_os_vendor", + "software_os_name", + "software_os_version", + "wap_push_ota_settings", + "wap_push_ota_bookmarks", + "wap_push_ota_app_internet", + "wap_push_ota_app_browser", + "wap_push_ota_app_mms", + "wap_push_ota_single_shot", + "wap_push_ota_multi_shot", + "wap_push_oma_settings", + "wap_push_oma_app_internet", + "wap_push_oma_app_browser", + "wap_push_oma_app_mms", + "wap_push_oma_cp_bookmarks", + "wap_1_2_1", + "wap_2_0", + "syncml_dm_settings", + "syncml_dm_acc_gprs", + "syncml_dm_app_internet", + "syncml_dm_app_browser", + "syncml_dm_app_mms", + "syncml_dm_app_bookmark", + "syncml_dm_app_java", + "wap_push_oma_app_ims", + "wap_push_oma_app_poc", + "j2me_midp_10", + "j2me_midp_20", + "j2me_midp_21", + "j2me_cldc_10", + "j2me_cldc_11", + "j2me_cldc_20", + "hnd_type", + ] + + def __init__(self, *, columns: Optional[list[str]] = None) -> None: + super().__init__(name="tacs", columns=columns) + + +infrastructure_table_map = dict(tacs=TacsTable, cells=CellsTable, sites=SitesTable) diff --git a/flowmachine/flowmachine/core/preflight.py b/flowmachine/flowmachine/core/preflight.py new file mode 100644 index 0000000000..94b2168f00 --- /dev/null +++ b/flowmachine/flowmachine/core/preflight.py @@ -0,0 +1,96 @@ +import inspect +from collections import defaultdict + +import typing + +import networkx as nx +import structlog + +from flowmachine.core.dependency_graph import ( + get_dependency_links, + _assemble_dependency_graph, +) +from flowmachine.core.errors.flowmachine_errors import ( + PreFlightFailedException, +) + +logger = structlog.get_logger("flowmachine.debug", submodule=__name__) + + +def pre_flight(method): + method.__hooks__ = getattr(method, "__hooks__", {}) + method.__hooks__["pre_flight"] = method + return method + + +def resolve_hooks(cls) -> typing.Dict[str, typing.List[typing.Callable]]: + """Add in the decorated processors + By doing this after constructing the class, we let standard inheritance + do all the hard 
work. + """ + mro = inspect.getmro(cls) + + hooks = defaultdict(list) + + for attr_name in dir(cls): + # Need to look up the actual descriptor, not whatever might be + # bound to the class. This needs to come from the __dict__ of the + # declaring class. + for parent in mro: + try: + attr = parent.__dict__[attr_name] + except KeyError: + continue + else: + break + else: + # In case we didn't find the attribute and didn't break above. + # We should never hit this - it's just here for completeness + # to exclude the possibility of attr being undefined. + continue + + try: + hook_config = attr.__hooks__ + except AttributeError: + pass + else: + for key in hook_config: + # Use name here so we can get the bound method later, in + # case the processor was a descriptor or something. + hooks[key].append(attr_name) + + return hooks + + +class Preflight: + def preflight(self): + logger.debug("Starting pre-flight checks.", query=str(self)) + errors = dict() + dep_graph = _assemble_dependency_graph( + dependencies=get_dependency_links(self), + attrs_func=lambda x: dict(query=x), + ) + deps = [dep_graph.nodes[id]["query"] for id in nx.topological_sort(dep_graph)][ + ::-1 + ] + + for dependency in deps: + for hook in resolve_hooks(dependency.__class__)["pre_flight"]: + logger.debug( + "Running hook", + query=str(self), + hook=hook, + dependency=str(dependency), + ) + try: + getattr(dependency, hook)() + except Exception as e: + errors.setdefault(str(dependency), list()).append(e) + if len(errors) > 0: + logger.debug( + "Pre-flight failed.", + query=str(self), + query_id=self.query_id, + errors=errors, + ) + raise PreFlightFailedException(self.query_id, errors) diff --git a/flowmachine/flowmachine/core/query.py b/flowmachine/flowmachine/core/query.py index a1959014b0..e9275fafc6 100644 --- a/flowmachine/flowmachine/core/query.py +++ b/flowmachine/flowmachine/core/query.py @@ -29,7 +29,11 @@ get_redis, submit_to_executor, ) -from flowmachine.core.errors.flowmachine_errors import QueryResetFailedException + +from flowmachine.core.errors.flowmachine_errors import ( + QueryResetFailedException, +) +from flowmachine.core.preflight import Preflight from flowmachine.core.query_state import QueryStateMachine from abc import ABCMeta, abstractmethod @@ -55,7 +59,7 @@ MAX_POSTGRES_NAME_LENGTH = 63 -class Query(metaclass=ABCMeta): +class Query(Preflight, metaclass=ABCMeta): """ The core base class of the flowmachine module. 
This should handle all input and output methods for our sql queries, so that @@ -366,7 +370,8 @@ def head(self, n=5): try: return self._df.head(n) except AttributeError: - Q = f"SELECT {self.column_names_as_string_list} FROM ({self.get_query()}) h LIMIT {n};" + q_string = self.get_query() + Q = f"SELECT {self.column_names_as_string_list} FROM ({q_string}) h LIMIT {n};" con = get_db().engine with con.begin() as trans: df = pd.read_sql_query(Q, con=trans) @@ -382,7 +387,14 @@ def get_table(self): flowmachine.core.Table The stored version of this Query as a Table object """ - return flowmachine.core.Table(self.fully_qualified_table_name) + if self.is_stored: + table = flowmachine.core.Table( + self.fully_qualified_table_name, columns=self.column_names + ) + table.preflight() + return table + else: + raise ValueError(f"{self} not stored on this connection.") def union(self, *other: "Query", all: bool = True): """ @@ -524,8 +536,9 @@ def _make_sql(self, name: str, schema: Union[str, None] = None) -> List[str]: logger.info("Table already exists") return [] + q_string = self._make_query() Q = f"""EXPLAIN (ANALYZE TRUE, TIMING FALSE, FORMAT JSON) CREATE TABLE {full_name} AS - (SELECT {self.column_names_as_string_list} FROM ({self._make_query()}) _)""" + (SELECT {self.column_names_as_string_list} FROM ({q_string}) _)""" queries.append(Q) # Make flowmachine user the owner to allow server to cleanup cache tables queries.append(f"ALTER TABLE {full_name} OWNER TO flowmachine;") diff --git a/flowmachine/flowmachine/core/random.py b/flowmachine/flowmachine/core/random.py index 7996ed3734..c229d5c3c2 100644 --- a/flowmachine/flowmachine/core/random.py +++ b/flowmachine/flowmachine/core/random.py @@ -6,9 +6,10 @@ Classes to select random samples from queries or tables. """ import random -from typing import List, Optional, Dict, Any, Union, Type, Tuple +from typing import List, Optional, Dict, Any, Type, Tuple from abc import ABCMeta, abstractmethod +from .preflight import pre_flight from .query import Query from .table import Table @@ -145,18 +146,20 @@ def __init__( fraction: Optional[float] = None, estimate_count: bool = False, ): + super().__init__( + query=query, size=size, fraction=fraction, estimate_count=estimate_count + ) + + @pre_flight + def check_not_inherited(self): # Raise a value error if the query is a table, and has children, as the # method relies on it not having children. - if isinstance(query, Table) and query.has_children(): + if isinstance(self.query, Table) and self.query.has_children(): raise ValueError( "It is not possible to use the 'system_rows' method in tables with inheritance " + "as it selects a random sample for each child table and not for the set as a whole." ) - super().__init__( - query=query, size=size, fraction=fraction, estimate_count=estimate_count - ) - def _make_query(self) -> str: # TABLESAMPLE only works on tables, so silently store this query self.query.store().result() diff --git a/flowmachine/flowmachine/core/server/action_handlers.py b/flowmachine/flowmachine/core/server/action_handlers.py index 8f53e5d809..07ba567efb 100644 --- a/flowmachine/flowmachine/core/server/action_handlers.py +++ b/flowmachine/flowmachine/core/server/action_handlers.py @@ -14,6 +14,7 @@ # action handler and also gracefully handles any potential errors. 
# import asyncio +import structlog from contextvars import copy_context from functools import partial import json @@ -41,6 +42,10 @@ from ..connection import MissingCheckError from ..dependency_graph import query_progress +from ..errors.flowmachine_errors import PreFlightFailedException +import traceback + +logger = structlog.get_logger("flowmachine.debug", submodule=__name__) async def action_handler__ping(config: "FlowmachineServerConfig") -> ZMQReply: @@ -80,6 +85,16 @@ async def action_handler__get_query_schemas( def _load_query_object(params: dict) -> "BaseExposedQuery": try: query_obj = FlowmachineQuerySchema().load(params) + query_obj._flowmachine_query_obj.preflight() # Note that we probably want to remove this call to allow getting qid faster + except PreFlightFailedException as exc: + orig_error_msg = exc.args[0] + error_msg = ( + f"Internal flowmachine server error: could not create query object using query schema. " + f"The original error was: '{orig_error_msg}'" + ) + raise QueryLoadError( + error_msg, params, orig_error_msg=orig_error_msg, errors=exc.errors + ) except TypeError as exc: # We need to catch TypeError here, otherwise they propagate up to # perform_action() and result in a very misleading error message. @@ -198,11 +213,24 @@ async def action_handler__run_query( ), ), ) - except Exception as e: + except PreFlightFailedException as exc: + orig_error_msg = exc.args[0] + error_msg = f"Preflight failed for {exc.query_id}." + return ZMQReply( + status="error", + msg=error_msg, + payload={ + "params": action_params, + "orig_error_msg": orig_error_msg, + "errors": exc.errors, + }, + ) + except Exception as exc: + logger.error(str(exc), exception=exc, traceback=traceback.format_exc()) return ZMQReply( status="error", msg="Unable to create query object.", - payload={"exception": str(e)}, + payload={"exception": str(exc)}, ) # Register the query as "known" (so that we can later look up the query kind diff --git a/flowmachine/flowmachine/core/server/query_schemas/base_exposed_query.py b/flowmachine/flowmachine/core/server/query_schemas/base_exposed_query.py index 27117cfd3d..7ed280d5eb 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/base_exposed_query.py +++ b/flowmachine/flowmachine/core/server/query_schemas/base_exposed_query.py @@ -52,6 +52,7 @@ def store_async(self, store_dependencies=True): Query ID that can be used to check the query state. 
""" q = self._flowmachine_query_obj + q.preflight() q.store(store_dependencies=store_dependencies) query_id = q.query_id diff --git a/flowmachine/flowmachine/core/server/query_schemas/base_schema.py b/flowmachine/flowmachine/core/server/query_schemas/base_schema.py index ca21d6cdba..82f57837b4 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/base_schema.py +++ b/flowmachine/flowmachine/core/server/query_schemas/base_schema.py @@ -4,7 +4,7 @@ from marshmallow import Schema, post_load -from flowmachine.core import make_spatial_unit +from flowmachine.core import make_spatial_unit, Table class BaseSchema(Schema): @@ -22,7 +22,13 @@ def remove_query_kind_if_present_and_load(self, params, **kwargs): elif "lon-lat" in aggregation_unit_string: spatial_unit_args = { "spatial_unit_type": "lon-lat", - "geom_table": geom_table, + "geom_table": ( + None + if geom_table is None + else Table( + geom_table, columns=[geom_table_join_on, "geom_point"] + ) + ), "geom_table_join_on": geom_table_join_on, } else: diff --git a/flowmachine/flowmachine/core/server/query_schemas/custom_fields.py b/flowmachine/flowmachine/core/server/query_schemas/custom_fields.py index bbfcfe35be..1a0273301e 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/custom_fields.py +++ b/flowmachine/flowmachine/core/server/query_schemas/custom_fields.py @@ -56,7 +56,7 @@ class EventTypes(fields.List): """ A list of strings representing an event type, for example "calls", "sms", "mds", "topups". - When deserialised, will be deduped, and prefixed with "events." + When deserialised, will be deduped. """ def __init__( @@ -77,13 +77,6 @@ def __init__( **kwargs, ) - def _deserialize(self, value, attr, data, **kwargs): - # Temporary workaround for https://github.com/Flowminder/FlowKit/issues/1015 until underlying issue resolved - return [ - f"events.{event_type}" - for event_type in set(super()._deserialize(value, attr, data, **kwargs)) - ] - class TotalBy(fields.String): """ diff --git a/flowmachine/flowmachine/core/spatial_unit.py b/flowmachine/flowmachine/core/spatial_unit.py index a33a5c16a6..e2103b943f 100644 --- a/flowmachine/flowmachine/core/spatial_unit.py +++ b/flowmachine/flowmachine/core/spatial_unit.py @@ -13,12 +13,14 @@ from flowmachine.utils import get_name_and_alias from flowmachine.core.errors import InvalidSpatialUnitError from flowmachine.core import Query, Table +from flowmachine.core.preflight import pre_flight from flowmachine.core.context import get_db from flowmachine.core.grid import Grid # TODO: Currently most spatial units require a FlowDB connection at init time. # It would be useful to remove this requirement wherever possible, and instead # implement a method to check whether the required data can be found in the DB. 
+from .infrastructure_table import infrastructure_table_map def _substitute_lat_lon(location_dict): @@ -278,7 +280,7 @@ def __init__( *, geom_table_column_names: Union[str, Iterable[str]], location_id_column_names: Union[str, Iterable[str]], - geom_table: Optional[Union[Query, str]] = None, + geom_table: Optional[Query] = None, mapping_table: Optional[Union[Query, str]] = None, geom_column: str = "geom", geom_table_join_on: Optional[str] = None, @@ -306,11 +308,11 @@ def __init__( if geom_table is None: # Creating a Table object here means that we don't have to handle # tables and Query objects differently in _make_query and get_geom_query - self.geom_table = Table(name=get_db().location_table) + self.geom_table = infrastructure_table_map["cells"]() elif isinstance(geom_table, Query): self.geom_table = geom_table else: - self.geom_table = Table(name=geom_table) + raise TypeError("geom_table must be a Query or Table object.") if mapping_table is not None: # Creating a Table object here means that we don't have to handle @@ -318,7 +320,10 @@ def __init__( if isinstance(mapping_table, Query): self.mapping_table = mapping_table else: - self.mapping_table = Table(name=mapping_table) + self.mapping_table = Table( + name=mapping_table, + columns=[location_table_join_on, geom_table_join_on], + ) if location_table_join_on not in self.mapping_table.column_names: raise ValueError( @@ -656,8 +661,8 @@ def __init__( self, *, geom_table_column_names: Union[str, Iterable[str]], - geom_table: Union[Query, str], - mapping_table: Optional[Union[Query, str]] = None, + geom_table: Query, + mapping_table: Optional[Query] = None, geom_column: str = "geom", geom_table_join_on: Optional[str] = None, ): @@ -702,15 +707,18 @@ class VersionedCellSpatialUnit(LonLatSpatialUnit): """ def __init__(self) -> None: - if get_db().location_table != "infrastructure.cells": - raise InvalidSpatialUnitError("Versioned cell spatial unit is unavailable.") super().__init__( geom_table_column_names=["version"], location_id_column_names=["location_id", "version"], - geom_table="infrastructure.cells", + geom_table=infrastructure_table_map["cells"](), ) + @pre_flight + def check_cells_available(self): + if get_db().location_table != "infrastructure.cells": + raise InvalidSpatialUnitError("Versioned cell spatial unit is unavailable.") + @property def canonical_name(self) -> str: return "versioned-cell" @@ -731,7 +739,7 @@ def __init__(self) -> None: "version", ], location_id_column_names=["site_id", "version"], - geom_table="infrastructure.sites", + geom_table=infrastructure_table_map["sites"](), geom_table_join_on="id", location_table_join_on="site_id", ) @@ -780,12 +788,15 @@ def __init__( col_name = f"admin{level}pcod AS pcod" else: col_name = region_id_column_name - table = f"geography.admin{level}" self.level = level super().__init__( geom_table_column_names=col_name, - geom_table=table, + geom_table=Table( + schema="geography", + name=f"admin{level}", + columns=[f"admin{level}pcod", f"admin{level}name", "geom"], + ), mapping_table=mapping_table, geom_table_join_on=None if mapping_table is None else f"admin{level}pcod", ) @@ -837,7 +848,7 @@ def make_spatial_unit( level: Optional[int] = None, region_id_column_name: Optional[Union[str, Iterable[str]]] = None, size: Union[float, int] = None, - geom_table: Optional[Union[Query, str]] = None, + geom_table: Optional[Query] = None, geom_column: str = "geom", mapping_table: Optional[Union[str, Query]] = None, geom_table_join_on: Optional[str] = None, diff --git 
a/flowmachine/flowmachine/core/sqlalchemy_utils.py b/flowmachine/flowmachine/core/sqlalchemy_utils.py index f5cf455d4c..6f0991945c 100644 --- a/flowmachine/flowmachine/core/sqlalchemy_utils.py +++ b/flowmachine/flowmachine/core/sqlalchemy_utils.py @@ -56,7 +56,7 @@ def get_sql_string(sqlalchemy_query): return sql -def get_string_representation(sqlalchemy_expr, engine=None): +def get_string_representation(sqlalchemy_expr): """ Return a string containing a SQL fragment which is compiled from the given sqlalchemy expression. @@ -66,7 +66,11 @@ def get_string_representation(sqlalchemy_expr, engine=None): String representation of the sqlalchemy expression. """ # assert isinstance(sqlalchemy_expr, ColumnElement) - return str(sqlalchemy_expr.compile(engine, compile_kwargs={"literal_binds": True})) + return str( + sqlalchemy_expr.compile( + dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True} + ) + ) def get_query_result_as_dataframe(query, *, engine): diff --git a/flowmachine/flowmachine/core/table.py b/flowmachine/flowmachine/core/table.py index 89d8d0f697..db03cf3968 100644 --- a/flowmachine/flowmachine/core/table.py +++ b/flowmachine/flowmachine/core/table.py @@ -9,12 +9,12 @@ """ from typing import List, Iterable, Optional -from flowmachine.core.query_state import QueryStateMachine -from .context import get_db, get_redis -from .errors import NotConnectedError -from .query import Query -from .subset import subset_factory -from .cache import write_cache_metadata +from flowmachine.core.query_state import QueryStateMachine, QueryState +from flowmachine.core.context import get_db, get_redis +from flowmachine.core.preflight import pre_flight +from flowmachine.core.query import Query +from flowmachine.core.subset import subset_factory +from flowmachine.core.cache import write_cache_metadata import structlog @@ -39,7 +39,7 @@ class Table(Query): Examples -------- - >>> t = Table(name="calls", schema="events") + >>> t = Table(name="calls", schema="events", columns=["id", "outgoing", "datetime", "duration"]) >>> t.head() id outgoing datetime duration \ 0 5wNJA-PdRJ4-jxEdG-yOXpZ True 2016-01-01 22:38:06+00:00 3393.0 @@ -63,7 +63,8 @@ class Table(Query): def __init__( self, - name: Optional[str] = None, + name: str, + *, schema: Optional[str] = None, columns: Optional[Iterable[str]] = None, ): @@ -78,56 +79,61 @@ def __init__( self.name = name self.schema = schema - self.fqn = "{}.{}".format(schema, name) if schema else name - if "." 
not in self.fqn: - raise ValueError("{} is not a valid table.".format(self.fqn)) + self.fqn = f"{schema}.{name}" if schema else name + + # Record provided columns to ensure that query_id differs with different columns + if isinstance(columns, str): # Wrap strings in a list + columns = [columns] + self.columns = columns + if self.columns is None or len(self.columns) == 0: + raise ValueError("No columns requested.") + super().__init__() + + @pre_flight + def check_exists(self): if not self.is_stored: - raise ValueError("{} is not a known table.".format(self.fqn)) + raise ValueError(f"{self.fqn} is not a known table.") + @pre_flight + def check_columns(self): # Get actual columns of this table from the database db_columns = list( zip( *get_db().fetch( f"""SELECT column_name from INFORMATION_SCHEMA.COLUMNS - WHERE table_name = '{self.name}' AND table_schema='{self.schema}'""" + WHERE table_name = '{self.name}' AND table_schema='{self.schema}'""" ) ) )[0] - if ( - columns is None or columns == [] - ): # No columns specified, setting them from the database - columns = db_columns - else: - self.parent_table = Table( - schema=self.schema, name=self.name - ) # Point to the full table - if isinstance(columns, str): # Wrap strings in a list - columns = [columns] - logger.debug( - "Checking provided columns against db columns.", - provided=columns, - db_columns=db_columns, + + logger.debug( + "Checking provided columns against db columns.", + provided=self.columns, + db_columns=db_columns, + ) + if not set(self.columns).issubset(db_columns): + raise ValueError( + f"{set(self.columns).difference(db_columns)} are not columns of {self.fqn}" ) - if not set(columns).issubset(db_columns): - raise ValueError( - "{} are not columns of {}".format( - set(columns).difference(db_columns), self.fqn - ) - ) - # Record provided columns to ensure that query_id differs with different columns - self.columns = columns - super().__init__() + @pre_flight + def ff_state_machine(self): # Table is immediately in a 'finished executing' state q_state_machine = QueryStateMachine( get_redis(), self.query_id, get_db().conn_id ) if not q_state_machine.is_completed: - q_state_machine.enqueue() - q_state_machine.execute() - with get_db().engine.begin() as trans: - write_cache_metadata(trans, self, compute_time=0) - q_state_machine.finish() + state, succeeded = q_state_machine.enqueue() + state, succeeded = q_state_machine.execute() + state, succeeded = q_state_machine.finish() + if succeeded: + with get_db().engine.begin() as trans: + write_cache_metadata(trans, self, compute_time=0) + state, succeeded = q_state_machine.finish() + if state != QueryState.COMPLETED: + raise RuntimeError( + f"Couldn't fast forward state machine for table {self}. 
State is: {state}" + ) def __format__(self, fmt): return f"" diff --git a/flowmachine/flowmachine/features/network/total_network_objects.py b/flowmachine/flowmachine/features/network/total_network_objects.py index 751358d11b..d8bd922a64 100644 --- a/flowmachine/flowmachine/features/network/total_network_objects.py +++ b/flowmachine/flowmachine/features/network/total_network_objects.py @@ -10,7 +10,7 @@ """ - +import datetime from typing import List, Optional, Tuple, Union from ...core.context import get_db @@ -66,8 +66,8 @@ class TotalNetworkObjects(GeoDataMixin, Query): def __init__( self, - start=None, - stop=None, + start: Union[str, datetime.date, datetime.datetime], + stop: Union[str, datetime.date, datetime.datetime], *, table="all", total_by="day", @@ -77,18 +77,12 @@ def __init__( subscriber_subset=None, subscriber_identifier="msisdn", ): - self.start = standardise_date( - get_db().min_date(table=table) if start is None else start - ) - self.stop = standardise_date( - get_db().max_date(table=table) if stop is None else stop - ) + self.start = standardise_date(start) + self.stop = standardise_date(stop) self.table = table if isinstance(self.table, str): self.table = self.table.lower() - if self.table != "all" and not self.table.startswith("events"): - self.table = "events.{}".format(self.table) network_object.verify_criterion("is_network_object") self.network_object = network_object diff --git a/flowmachine/flowmachine/features/spatial/versioned_infrastructure.py b/flowmachine/flowmachine/features/spatial/versioned_infrastructure.py index d9318296f7..12f5222df2 100644 --- a/flowmachine/flowmachine/features/spatial/versioned_infrastructure.py +++ b/flowmachine/flowmachine/features/spatial/versioned_infrastructure.py @@ -14,6 +14,7 @@ from datetime import datetime from flowmachine.core import Table +from ...core.infrastructure_table import infrastructure_table_map from ...core.query import Query @@ -62,7 +63,7 @@ def __init__(self, table="sites", date=None): if date == None: date = datetime.now().strftime("%Y-%m-%d") - self.table = Table(schema="infrastructure", name=table) + self.table = infrastructure_table_map[table]() self.date = date super().__init__() diff --git a/flowmachine/flowmachine/features/subscriber/active_subscribers.py b/flowmachine/flowmachine/features/subscriber/active_subscribers.py index 2eeeab1159..21506a2372 100644 --- a/flowmachine/flowmachine/features/subscriber/active_subscribers.py +++ b/flowmachine/flowmachine/features/subscriber/active_subscribers.py @@ -80,7 +80,7 @@ class ActiveSubscribers(ExposedDatetimeMixin, Query): total_major_periods=4, minor_period_threshold=1, major_period_threshold=3, - tables=["events.calls"], + tables=["calls"], ) Returns subscribers that were active in at least two ten minute intervals within half an hour, @@ -94,7 +94,7 @@ class ActiveSubscribers(ExposedDatetimeMixin, Query): minor_period_threshold=2, major_period_threshold=3, period_unit="minutes", - tables=["events.calls"], + tables=["calls"], ) diff --git a/flowmachine/flowmachine/features/subscriber/contact_balance.py b/flowmachine/flowmachine/features/subscriber/contact_balance.py index 693b9c5d9a..7ea3960331 100644 --- a/flowmachine/flowmachine/features/subscriber/contact_balance.py +++ b/flowmachine/flowmachine/features/subscriber/contact_balance.py @@ -76,14 +76,17 @@ def __init__( exclude_self_calls=True, subscriber_subset=None, ): - self.tables = tables + self.tables = ( + ["calls", "sms"] + if (isinstance(tables, str) and tables.lower() == "all") or tables is None + 
else tables + ) self.start = standardise_date(start) self.stop = standardise_date(stop) self.hours = hours self.direction = Direction(direction) self.subscriber_identifier = subscriber_identifier self.exclude_self_calls = exclude_self_calls - self.tables = tables column_list = [ self.subscriber_identifier, diff --git a/flowmachine/flowmachine/features/subscriber/mds_volume.py b/flowmachine/flowmachine/features/subscriber/mds_volume.py index eef640f80a..c28c58a4a5 100644 --- a/flowmachine/flowmachine/features/subscriber/mds_volume.py +++ b/flowmachine/flowmachine/features/subscriber/mds_volume.py @@ -68,7 +68,7 @@ def __init__( self.hours = hours self.volume = volume self.statistic = Statistic(statistic.lower()) - self.tables = "events.mds" + self.tables = "mds" if self.volume not in {"total", "upload", "download"}: raise ValueError(f"{self.volume} is not a valid volume.") diff --git a/flowmachine/flowmachine/features/subscriber/subscriber_call_durations.py b/flowmachine/flowmachine/features/subscriber/subscriber_call_durations.py index 5a33687d84..3c15acbee3 100644 --- a/flowmachine/flowmachine/features/subscriber/subscriber_call_durations.py +++ b/flowmachine/flowmachine/features/subscriber/subscriber_call_durations.py @@ -87,7 +87,7 @@ def __init__( self.unioned_query = EventsTablesUnion( self.start, self.stop, - tables="events.calls", + tables="calls", columns=column_list, hours=hours, subscriber_subset=subscriber_subset, @@ -186,7 +186,7 @@ def __init__( EventsTablesUnion( self.start, self.stop, - tables="events.calls", + tables="calls", columns=column_list, hours=hours, subscriber_subset=subscriber_subset, @@ -276,7 +276,7 @@ def __init__( self.unioned_query = EventsTablesUnion( self.start, self.stop, - tables="events.calls", + tables="calls", columns=column_list, hours=hours, subscriber_subset=subscriber_subset, @@ -377,7 +377,7 @@ def __init__( EventsTablesUnion( self.start, self.stop, - tables="events.calls", + tables="calls", columns=column_list, hours=hours, subscriber_subset=subscriber_subset, diff --git a/flowmachine/flowmachine/features/subscriber/subscriber_tacs.py b/flowmachine/flowmachine/features/subscriber/subscriber_tacs.py index 09e5787f35..6abf671c48 100644 --- a/flowmachine/flowmachine/features/subscriber/subscriber_tacs.py +++ b/flowmachine/flowmachine/features/subscriber/subscriber_tacs.py @@ -17,6 +17,7 @@ from .metaclasses import SubscriberFeature from ...core import Table from flowmachine.utils import standardise_date +from ...core.infrastructure_table import TacsTable valid_characteristics = { "brand", @@ -312,7 +313,7 @@ def __init__( subscriber_identifier=subscriber_identifier, subscriber_subset=subscriber_subset, ) - self.tacs = Table("infrastructure.tacs") + self.tacs = TacsTable() self.joined = self.subscriber_tacs.join(self.tacs, "tac", "id", how="left") super().__init__() @@ -393,8 +394,7 @@ def __init__( method=method, subscriber_subset=subscriber_subset, ) - self.method = method - self.tacs = Table("infrastructure.tacs") + self.tacs = TacsTable() self.joined = self.subscriber_tac.join(self.tacs, "tac", "id", how="left") super().__init__() diff --git a/flowmachine/flowmachine/features/subscriber/topup_amount.py b/flowmachine/flowmachine/features/subscriber/topup_amount.py index 2b9d7967e1..2a02293b11 100644 --- a/flowmachine/flowmachine/features/subscriber/topup_amount.py +++ b/flowmachine/flowmachine/features/subscriber/topup_amount.py @@ -66,7 +66,7 @@ def __init__( self.subscriber_identifier = subscriber_identifier self.hours = hours self.statistic = 
Statistic(statistic.lower()) - self.tables = "events.topups" + self.tables = "topups" column_list = [self.subscriber_identifier, "recharge_amount"] diff --git a/flowmachine/flowmachine/features/subscriber/topup_balance.py b/flowmachine/flowmachine/features/subscriber/topup_balance.py index f3093e2a8d..8df36ec5e0 100644 --- a/flowmachine/flowmachine/features/subscriber/topup_balance.py +++ b/flowmachine/flowmachine/features/subscriber/topup_balance.py @@ -91,7 +91,7 @@ def __init__( self.subscriber_identifier = subscriber_identifier self.hours = hours self.statistic = Statistic(statistic.lower()) - self.tables = "events.topups" + self.tables = "topups" column_list = [ self.subscriber_identifier, diff --git a/flowmachine/flowmachine/features/utilities/event_table_subset.py b/flowmachine/flowmachine/features/utilities/event_table_subset.py index 5ab3f8d723..0dd47fc753 100644 --- a/flowmachine/flowmachine/features/utilities/event_table_subset.py +++ b/flowmachine/flowmachine/features/utilities/event_table_subset.py @@ -3,16 +3,16 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. -import datetime -import pandas as pd import warnings from sqlalchemy import select from typing import List, Optional, Tuple -from ...core import Query, Table -from ...core.context import get_db -from ...core.errors import MissingDateError -from ...core.sqlalchemy_utils import ( +from flowmachine.core.query import Query +from flowmachine.core.context import get_db +from flowmachine.core.errors import MissingDateError +from flowmachine.core.events_table import events_table_map +from flowmachine.core.preflight import pre_flight +from flowmachine.core.sqlalchemy_utils import ( get_sqlalchemy_table_definition, make_sqlalchemy_column_from_flowmachine_column_description, get_sql_string, @@ -32,19 +32,15 @@ class EventTableSubset(Query): Parameters ---------- - start : str, default None - iso format date range for the beginning of the time frame, e.g. - 2016-01-01 or 2016-01-01 14:03:01. If None, it will use the - earliest date seen in the `events.calls` table. - stop : str, default None - As above. If None, it will use the latest date seen in the - `events.calls` table. - hours : tuple of ints, default None + start, stop : str + iso format date range for the beginning and end of the time frame, e.g. + 2016-01-01 or 2016-01-01 14:03:01. + hours : tuple of ints, default 'None' Subset the result within certain hours, e.g. (4,17) This will subset the query only with these hours, but across all specified days. Or set to 'all' to include all hours. 
- table : str, default 'events.calls' + table : str, default 'calls' schema qualified name of the table which the analysis is based upon subscriber_identifier : {'msisdn', 'imei'}, default 'msisdn' @@ -72,13 +68,13 @@ class EventTableSubset(Query): def __init__( self, *, - start=None, - stop=None, + start, + stop, hours: Optional[Tuple[int, int]] = None, hour_slices=None, - table="events.calls", + table="calls", subscriber_subset=None, - columns=["*"], + columns=None, subscriber_identifier="msisdn", ): if hours == "all": @@ -121,12 +117,9 @@ def __init__( self.hours = hours self.subscriber_subsetter = make_subscriber_subsetter(subscriber_subset) self.subscriber_identifier = subscriber_identifier.lower() - if columns == ["*"]: - self.table_ORIG = Table(table) - columns = self.table_ORIG.column_names - else: - self.table_ORIG = Table(table, columns=columns) - self.columns = set(columns) + self.table_ORIG = events_table_map[table](columns=columns) + + self.columns = set(self.table_ORIG.column_names) try: self.columns.remove(subscriber_identifier) self.columns.add(f"{subscriber_identifier} AS subscriber") @@ -140,57 +133,31 @@ def __init__( ) self.columns = sorted(self.columns) - self.sqlalchemy_table = get_sqlalchemy_table_definition( - self.table_ORIG.fully_qualified_table_name, - engine=get_db().engine, - ) - if self.start == self.stop: raise ValueError("Start and stop are the same.") super().__init__() - # This needs to happen after the parent classes init method has been - # called as it relies upon the connection object existing - self._check_dates() - @property def column_names(self) -> List[str]: return [c.split(" AS ")[-1] for c in self.columns] - def _check_dates(self): + @pre_flight + def check_dates(self): + logger.debug("Checking dates are valid.") # Handle the logic for dealing with missing dates. # If there are no dates present, then we raise an error # if some are present, but some are missing we raise a # warning. - # If the subscriber does not pass a start or stop date, then we take - # the min/max date in the events.calls table - if self.start is None: - d1 = ( - get_db() - .min_date(self.table_ORIG.fully_qualified_table_name.split(".")[1]) - .strftime("%Y-%m-%d %H:%M:%S") - ) - else: - d1 = self.start.split()[0] - - if self.stop is None: - d2 = ( - get_db() - .max_date(self.table_ORIG.fully_qualified_table_name.split(".")[1]) - .strftime("%Y-%m-%d %H:%M:%S") - ) - else: - d2 = self.stop.split()[0] + d1 = self.start.split()[0] + d2 = self.stop.split()[0] all_dates = list_of_dates(d1, d2) # Slightly annoying feature, but if the subscriber passes a date such as '2016-01-02' # this will be interpreted as midnight, so we don't want to include this in our # calculations. 
Check for this here, an if this is the case pop the final element # of the list - if (self.stop is not None) and ( - len(self.stop) == 10 or self.stop.endswith("00:00:00") - ): + if len(self.stop) == 10 or self.stop.endswith("00:00:00"): all_dates.pop(-1) # This will be a true false list for whether each of the dates # is present in the database @@ -216,26 +183,26 @@ def _check_dates(self): stacklevel=2, ) - def _make_query_with_sqlalchemy(self): + def _make_query(self): + sqlalchemy_table = get_sqlalchemy_table_definition( + self.table_ORIG.fully_qualified_table_name, + engine=get_db().engine, + ) sqlalchemy_columns = [ make_sqlalchemy_column_from_flowmachine_column_description( - self.sqlalchemy_table, column_str + sqlalchemy_table, column_str ) for column_str in self.columns ] select_stmt = select(*sqlalchemy_columns) if self.start is not None: - select_stmt = select_stmt.where( - self.sqlalchemy_table.c.datetime >= self.start - ) + select_stmt = select_stmt.where(sqlalchemy_table.c.datetime >= self.start) if self.stop is not None: - select_stmt = select_stmt.where( - self.sqlalchemy_table.c.datetime < self.stop - ) + select_stmt = select_stmt.where(sqlalchemy_table.c.datetime < self.stop) select_stmt = select_stmt.where( - self.hour_slices.get_subsetting_condition(self.sqlalchemy_table.c.datetime) + self.hour_slices.get_subsetting_condition(sqlalchemy_table.c.datetime) ) select_stmt = self.subscriber_subsetter.apply_subset_if_needed( select_stmt, subscriber_identifier=self.subscriber_identifier @@ -243,8 +210,6 @@ def _make_query_with_sqlalchemy(self): return get_sql_string(select_stmt) - _make_query = _make_query_with_sqlalchemy - @property def fully_qualified_table_name(self): # EventTableSubset are a simple select from events, and should not be cached diff --git a/flowmachine/flowmachine/features/utilities/events_tables_union.py b/flowmachine/flowmachine/features/utilities/events_tables_union.py index 8538a30e46..cf51c7cd13 100644 --- a/flowmachine/flowmachine/features/utilities/events_tables_union.py +++ b/flowmachine/flowmachine/features/utilities/events_tables_union.py @@ -1,14 +1,14 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +import datetime import structlog import warnings -from typing import List, Optional, Tuple +from typing import List, Union, Optional, Tuple -from ...core import Query -from ...core.context import get_db -from ...core.errors import MissingDateError +from flowmachine.core import Query +from flowmachine.core.errors import MissingDateError from .event_table_subset import EventTableSubset from flowmachine.utils import standardise_date @@ -43,8 +43,8 @@ class EventsTablesUnion(Query): def __init__( self, - start, - stop, + start: Union[str, datetime.date, datetime.datetime], + stop: Union[str, datetime.date, datetime.datetime], *, columns, tables=None, @@ -62,12 +62,16 @@ def __init__( self.start = standardise_date(start) self.stop = standardise_date(stop) self.columns = columns - self.tables = self._parse_tables(tables) + self.tables = _parse_tables(tables) if "*" in columns and len(self.tables) != 1: raise ValueError( - "Must give named tables when combining multiple event type tables." + "Must give named columns when combining multiple event type tables." 
) - self.date_subsets = self._make_table_list( + self.date_subsets = _make_table_list( + tables=self.tables, + start=self.start, + stop=self.stop, + columns=columns, hours=hours, subscriber_subset=subscriber_subset, subscriber_identifier=subscriber_identifier, @@ -81,48 +85,6 @@ def column_names(self) -> List[str]: 0 ].column_names # Use in preference to self.columns which might be ["*"] - def _parse_tables(self, tables): - if tables is None: - return [f"events.{t}" for t in get_db().subscriber_tables] - elif isinstance(tables, str) and len(tables) > 0: - return [tables] - elif isinstance(tables, str): - raise ValueError("Empty table name.") - elif not isinstance(tables, list) or not all( - [isinstance(tbl, str) for tbl in tables] - ): - raise ValueError("Tables must be a string or list of strings.") - elif len(tables) == 0: - raise ValueError("Empty tables list.") - else: - return tables - - def _make_table_list(self, *, hours, subscriber_subset, subscriber_identifier): - """ - Makes a list of EventTableSubset queries. - """ - - date_subsets = [] - for table in self.tables: - try: - sql = EventTableSubset( - start=self.start, - stop=self.stop, - table=table, - columns=self.columns, - hours=hours, - subscriber_subset=subscriber_subset, - subscriber_identifier=subscriber_identifier, - ) - date_subsets.append(sql) - except MissingDateError: - warnings.warn( - f"No data in {table} for {self.start}–{self.stop}", stacklevel=2 - ) - if not date_subsets: - raise MissingDateError(self.start, self.stop) - return date_subsets - def _make_query(self): # Get the list of tables, select the relevant columns and union # them all @@ -134,3 +96,50 @@ def _make_query(self): def fully_qualified_table_name(self): # EventTableSubset are a simple select from events, and should not be cached raise NotImplementedError + + +def _parse_tables(tables): + if tables is None: + return ( + "calls", + "sms", + ) # This should default to all the tables really, but that would break all the tests + if tables == "": + raise ValueError("Empty table name.") + elif isinstance(tables, str): + return [tables] + elif not isinstance(tables, list) or not all( + [isinstance(tbl, str) for tbl in tables] + ): + raise ValueError("Tables must be a string or list of strings.") + elif len(tables) == 0: + raise ValueError("Empty tables list.") + else: + return sorted(set(tables)) + + +def _make_table_list( + *, tables, start, stop, columns, hours, subscriber_subset, subscriber_identifier +): + """ + Makes a list of EventTableSubset queries. + """ + + date_subsets = [] + for table in tables: + try: + sql = EventTableSubset( + start=start, + stop=stop, + table=table, + columns=columns, + hours=hours, + subscriber_subset=subscriber_subset, + subscriber_identifier=subscriber_identifier, + ) + date_subsets.append(sql) + except MissingDateError: + warnings.warn(f"No data in {table} for {start}–{stop}", stacklevel=2) + if not date_subsets: + raise MissingDateError(start, stop) + return date_subsets diff --git a/flowmachine/flowmachine/features/utilities/sets.py b/flowmachine/flowmachine/features/utilities/sets.py index 1ff69fc019..a4bacea0c5 100644 --- a/flowmachine/flowmachine/features/utilities/sets.py +++ b/flowmachine/flowmachine/features/utilities/sets.py @@ -46,9 +46,7 @@ class UniqueSubscribers(Query): table : str, default 'all' Table on which to perform the query. By default it will look at ALL tables, which are any tables with subscriber information - in them, specified via subscriber_tables in flowmachine.yml. 
Otherwise - you need to specify a full table (with a schema) such as - 'events.calls'. + in them, specified via subscriber_tables in flowmachine.yml. subscriber_identifier : {'msisdn', 'imei'}, default 'msisdn' Either msisdn, or imei, the column that identifies the subscriber. subscriber_subset : str, list, flowmachine.core.Query, flowmachine.core.Table, default None diff --git a/flowmachine/tests/conftest.py b/flowmachine/tests/conftest.py index 103461aed5..1bccfd19f1 100644 --- a/flowmachine/tests/conftest.py +++ b/flowmachine/tests/conftest.py @@ -25,7 +25,7 @@ ) import flowmachine -from flowmachine.core import make_spatial_unit +from flowmachine.core import make_spatial_unit, Table from flowmachine.core.cache import reset_cache from flowmachine.core.context import ( redis_connection, @@ -103,12 +103,12 @@ def parse_json(): { "spatial_unit_type": "polygon", "region_id_column_name": "admin3pcod", - "geom_table": "geography.admin3", + "geom_table": Table("geography.admin3", columns=["geom", "admin3pcod"]), }, { "spatial_unit_type": "polygon", "region_id_column_name": "id AS site_id", - "geom_table": "infrastructure.sites", + "geom_table": Table("infrastructure.sites", columns=["geom_point", "id"]), "geom_column": "geom_point", }, ], @@ -149,7 +149,7 @@ def skip_datecheck(request, monkeypatch): """ run_date_checks = request.node.get_closest_marker("check_available_dates", False) if not run_date_checks: - monkeypatch.setattr(EventTableSubset, "_check_dates", lambda x: True) + monkeypatch.setattr(EventTableSubset, "check_dates", lambda x: True) @pytest.fixture(autouse=True) diff --git a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_1_sql.approved.txt b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_1_sql.approved.txt index 3747f837cf..d6f2f54199 100644 --- a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_1_sql.approved.txt +++ b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_1_sql.approved.txt @@ -27,25 +27,8 @@ FROM (SELECT subscriber, loc_table.date_of_last_service, geom_table.admin3pcod AS pcod FROM infrastructure.cells AS loc_table - INNER JOIN (SELECT gid, - admin0name, - admin0pcod, - admin1name, - admin1pcod, - admin2name, - admin2pcod, + INNER JOIN (SELECT admin3pcod, admin3name, - admin3pcod, - admin3refn, - admin3altn, - admin3al_1, - date, - validon, - validto, - shape_star, - shape_stle, - shape_leng, - shape_area, geom FROM geography.admin3) AS geom_table ON st_within(CAST(loc_table.geom_point AS geometry), CAST(st_setsrid(geom_table.geom, 4326) AS geometry))) AS sites ON l.location_id = sites.location_id diff --git a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_2_sql.approved.txt b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_2_sql.approved.txt index 9b2e96744b..c735dbb079 100644 --- a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_2_sql.approved.txt +++ b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_2_sql.approved.txt @@ -42,23 +42,8 @@ FROM (SELECT times_visited.subscriber, loc_table.date_of_last_service, geom_table.admin2pcod AS pcod FROM infrastructure.cells AS loc_table - INNER JOIN (SELECT gid, - admin0name, - admin0pcod, - admin1name, - admin1pcod, + INNER 
JOIN (SELECT admin2pcod, admin2name, - admin2pcod, - admin2refn, - admin2altn, - admin2al_1, - date, - validon, - validto, - shape_star, - shape_stle, - shape_leng, - shape_area, geom FROM geography.admin2) AS geom_table ON st_within(CAST(loc_table.geom_point AS geometry), CAST(st_setsrid(geom_table.geom, 4326) AS geometry))) AS sites ON l.location_id = sites.location_id diff --git a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_4_sql.approved.txt b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_4_sql.approved.txt index 56ee111386..291ab04249 100644 --- a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_4_sql.approved.txt +++ b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_4_sql.approved.txt @@ -25,25 +25,8 @@ FROM (SELECT subscriber, loc_table.date_of_last_service, geom_table.admin3pcod AS pcod FROM infrastructure.cells AS loc_table - INNER JOIN (SELECT gid, - admin0name, - admin0pcod, - admin1name, - admin1pcod, - admin2name, - admin2pcod, + INNER JOIN (SELECT admin3pcod, admin3name, - admin3pcod, - admin3refn, - admin3altn, - admin3al_1, - date, - validon, - validto, - shape_star, - shape_stle, - shape_leng, - shape_area, geom FROM geography.admin3) AS geom_table ON st_within(CAST(loc_table.geom_point AS geometry), CAST(st_setsrid(geom_table.geom, 4326) AS geometry))) AS sites ON l.location_id = sites.location_id diff --git a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_6_sql.approved.txt b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_6_sql.approved.txt index 4d7fb8cb73..8d7f2008ce 100644 --- a/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_6_sql.approved.txt +++ b/flowmachine/tests/functional_tests/approved_files/test_sql_strings_and_results.test_daily_location_6_sql.approved.txt @@ -28,25 +28,8 @@ FROM (SELECT subscriber, loc_table.date_of_last_service, geom_table.admin3pcod AS pcod FROM infrastructure.cells AS loc_table - INNER JOIN (SELECT gid, - admin0name, - admin0pcod, - admin1name, - admin1pcod, - admin2name, - admin2pcod, + INNER JOIN (SELECT admin3pcod, admin3name, - admin3pcod, - admin3refn, - admin3altn, - admin3al_1, - date, - validon, - validto, - shape_star, - shape_stle, - shape_leng, - shape_area, geom FROM geography.admin3) AS geom_table ON st_within(CAST(loc_table.geom_point AS geometry), CAST(st_setsrid(geom_table.geom, 4326) AS geometry))) AS sites ON l.location_id = sites.location_id diff --git a/flowmachine/tests/functional_tests/test_sql_strings_and_results.py b/flowmachine/tests/functional_tests/test_sql_strings_and_results.py index be52003b5f..bee087afb2 100644 --- a/flowmachine/tests/functional_tests/test_sql_strings_and_results.py +++ b/flowmachine/tests/functional_tests/test_sql_strings_and_results.py @@ -124,7 +124,7 @@ def test_daily_location_4_sql(diff_reporter): ) dl = daily_location( "2016-01-05", - table="events.calls", + table="calls", hours=(22, 6), subscriber_subset=subset_query, ) @@ -142,7 +142,7 @@ def test_daily_location_4_df(get_dataframe, diff_reporter): ) dl = daily_location( "2016-01-05", - table="events.calls", + table="calls", hours=(22, 6), subscriber_subset=subset_query, ) @@ -210,9 +210,7 @@ def test_daily_location_6_sql(diff_reporter): """, ["subscriber"], ) - dl = 
daily_location( - "2016-01-03", table="events.calls", subscriber_subset=subset_query - ) + dl = daily_location("2016-01-03", table="calls", subscriber_subset=subset_query) sql = pretty_sql(dl.get_query()) diff_reporter(sql) @@ -229,9 +227,7 @@ def test_daily_location_6_df(get_dataframe, diff_reporter): """, ["outgoing", "datetime", "duration", "subscriber"], ) - dl = daily_location( - "2016-01-03", table="events.calls", subscriber_subset=subset_query - ) + dl = daily_location("2016-01-03", table="calls", subscriber_subset=subset_query) df = get_dataframe(dl) diff_reporter(df.to_csv()) diff --git a/flowmachine/tests/server/test_action_handlers.py b/flowmachine/tests/server/test_action_handlers.py index 876107ac87..bf224b5b2e 100644 --- a/flowmachine/tests/server/test_action_handlers.py +++ b/flowmachine/tests/server/test_action_handlers.py @@ -187,10 +187,10 @@ async def test_run_query_error_handled(dummy_redis, server_config): ), ) assert msg.status == ZMQReplyStatus.ERROR - assert ( - msg.msg - == "Internal flowmachine server error: could not create query object using query schema. The original error was: 'zip() argument after * must be an iterable, not Mock'" + assert msg.msg.rstrip().startswith( + "Internal flowmachine server error: could not create query object using query schema. The original error was: 'Pre-flight failed for '7b71413efc91213e798ca3bd53107186'. Errors:" ) + assert len(msg.payload["errors"]) == 3 @pytest.mark.parametrize("bad_unit", ["NOT_A_VALID_UNIT", "admin4"]) diff --git a/flowmachine/tests/test_active_subscribers.py b/flowmachine/tests/test_active_subscribers.py index 2f610cc766..58cebaef15 100644 --- a/flowmachine/tests/test_active_subscribers.py +++ b/flowmachine/tests/test_active_subscribers.py @@ -19,7 +19,7 @@ def test_active_subscribers_one_day(get_dataframe): total_major_periods=1, minor_period_threshold=3, major_period_threshold=1, - tables=["events.calls"], + tables=["calls"], ) out = get_dataframe(active_subscribers).iloc[0:5] print(out) @@ -44,7 +44,7 @@ def test_active_subscribers_many_days(get_dataframe): total_major_periods=4, minor_period_threshold=1, major_period_threshold=3, - tables=["events.calls"], + tables=["calls"], ) out = get_dataframe(active_subscribers).iloc[0:5] print(out) @@ -72,7 +72,7 @@ def test_active_subscribers_custom_period(get_dataframe): minor_period_threshold=2, major_period_threshold=3, period_unit="minutes", - tables=["events.calls"], + tables=["calls"], ) assert len(active_subscribers.major_period_queries) == 4 assert active_subscribers.major_period_queries[2].start == "2016-01-01 21:00:00" diff --git a/flowmachine/tests/test_cache.py b/flowmachine/tests/test_cache.py index 68efd8ad92..b49e499cd4 100644 --- a/flowmachine/tests/test_cache.py +++ b/flowmachine/tests/test_cache.py @@ -39,6 +39,7 @@ def test_do_cache_simple(flowmachine_connect): """ dl1 = daily_location("2016-01-01") + dl1.preflight() with get_db().engine.begin() as trans: write_cache_metadata(trans, dl1) assert cache_table_exists(get_db(), dl1.query_id) @@ -51,6 +52,7 @@ def test_do_cache_multi(flowmachine_connect): """ hl1 = ModalLocation(daily_location("2016-01-01"), daily_location("2016-01-02")) + hl1.preflight() with get_db().engine.begin() as trans: write_cache_metadata(trans, hl1) @@ -65,6 +67,7 @@ def test_do_cache_nested(flowmachine_connect): hl1 = ModalLocation(daily_location("2016-01-01"), daily_location("2016-01-02")) hl2 = ModalLocation(daily_location("2016-01-03"), daily_location("2016-01-04")) flow = Flows(hl1, hl2) + flow.preflight() with 
get_db().engine.begin() as trans:
         write_cache_metadata(trans, flow)

@@ -144,7 +147,9 @@ def test_invalidate_cache_multi(flowmachine_connect):
     assert not cache_table_exists(get_db(), dl1.query_id)
     assert not cache_table_exists(get_db(), hl1.query_id)
     has_deps = bool(get_db().fetch("SELECT * FROM cache.dependencies"))
-    assert has_deps  # the remaining dependencies are due to underlying Table objects
+    assert (
+        not has_deps
+    )  # Table objects no longer record dependencies, so none should remain


def test_invalidate_cache_midchain(flowmachine_connect):
@@ -308,7 +313,7 @@ def column_names(self):
             return ["value"]

         def _make_query(self):
-            return "select 1 as value"
+            return self.nested.get_query()

     q = NestTestQuery()
     q_id = q.query_id
diff --git a/flowmachine/tests/test_cache_utils.py b/flowmachine/tests/test_cache_utils.py
index 04eed46190..6596a26ac7 100644
--- a/flowmachine/tests/test_cache_utils.py
+++ b/flowmachine/tests/test_cache_utils.py
@@ -56,7 +56,7 @@ def test_scoring(flowmachine_connect):
     dl_time = get_compute_time(get_db(), dl.query_id)
     dl_size = get_size_of_table(get_db(), dl.table_name, "cache")
     initial_score = get_score(get_db(), dl.query_id)
-    cachey_scorer = Scorer(halflife=1000.0)
+    cachey_scorer = Scorer(halflife=get_cache_half_life(get_db()))
     cache_score = cachey_scorer.touch("dl", dl_time / dl_size)
     assert cache_score == pytest.approx(initial_score)

@@ -78,6 +78,7 @@ def test_touch_cache_record_for_query(flowmachine_connect):
     Touching a cache record for a query should update access count, last accessed, & counter.
     """
     table = daily_location("2016-01-01").store().result()
+    initial_touches = get_db().fetch("SELECT nextval('cache.cache_touches');")[0][0]

     assert (
         1
@@ -96,7 +97,10 @@ def test_touch_cache_record_for_query(flowmachine_connect):
         )[0][0]
     )
     # Two cache touches should have been recorded
-    assert 4 == get_db().fetch("SELECT nextval('cache.cache_touches');")[0][0]
+    assert (
+        initial_touches + 2
+        == get_db().fetch("SELECT nextval('cache.cache_touches');")[0][0]
+    )
     assert (
         accessed_at
         < get_db().fetch(
@@ -109,7 +113,8 @@ def test_touch_cache_record_for_table(flowmachine_connect):
     """
     Touching a cache record for a table should update access count and last accessed but not touch score, or counter.
     """
-    table = Table("events.calls_20160101")
+    table = Table("events.calls_20160101", columns=["id"])
+    table.preflight()
     with get_db().engine.begin() as conn:
         conn.exec_driver_sql(
             f"UPDATE cache.cached SET compute_time = 1 WHERE query_id=%(ident)s",
@@ -525,6 +530,7 @@ def test_cache_reset_protects_tables(flowmachine_connect):
     """
     # Regression test for https://github.com/Flowminder/FlowKit/issues/832
     dl_query = daily_location(date="2016-01-03", method="last")
+    dl_query.preflight()
     reset_cache(get_db(), get_redis())
     for dep in dl_query._get_stored_dependencies():
         assert dep.query_id in [x.query_id for x in Query.get_stored()]
diff --git a/flowmachine/tests/test_daily_location.py b/flowmachine/tests/test_daily_location.py
index 818e64ffb4..775f5c2439 100644
--- a/flowmachine/tests/test_daily_location.py
+++ b/flowmachine/tests/test_daily_location.py
@@ -6,6 +6,7 @@
 from flowmachine.core.errors import MissingDateError
 from flowmachine.core import make_spatial_unit
+from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException
 from flowmachine.features import daily_location, MostFrequentLocation

@@ -77,5 +78,13 @@ def test_daily_locs_errors():
     daily_location() errors when we ask for a date that does not exist.
""" + with pytest.raises(PreFlightFailedException) as exc: + daily_location("2016-01-31").preflight() with pytest.raises(MissingDateError): - daily_location("2016-01-31") + raise exc.value.errors[ + "" + ][0] + with pytest.raises(MissingDateError): + raise exc.value.errors[ + "" + ][0] diff --git a/flowmachine/tests/test_dependency_graph.py b/flowmachine/tests/test_dependency_graph.py index 5b665c79c7..a1641b8f48 100644 --- a/flowmachine/tests/test_dependency_graph.py +++ b/flowmachine/tests/test_dependency_graph.py @@ -49,7 +49,11 @@ def test_print_dependency_tree(): expected_output = textwrap.dedent( """\ + - + - - + - + - - - - @@ -57,15 +61,9 @@ def test_print_dependency_tree(): - - - - - - - - - - - - - - - - - - """ ) diff --git a/flowmachine/tests/test_event_table_subset.py b/flowmachine/tests/test_event_table_subset.py index 8bbc60715c..a88f2910c0 100644 --- a/flowmachine/tests/test_event_table_subset.py +++ b/flowmachine/tests/test_event_table_subset.py @@ -14,6 +14,7 @@ import flowmachine.core from flowmachine.core.errors import MissingDateError +from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException from flowmachine.features.utilities.event_table_subset import EventTableSubset @@ -47,7 +48,7 @@ def test_warns_on_missing(): """ message = "115 of 122 calendar dates missing. Earliest date is 2016-01-01, latest is 2016-01-07" with pytest.warns(UserWarning, match=message): - EventTableSubset(start="2016-01-01", stop="2016-05-02") + EventTableSubset(start="2016-01-01", stop="2016-05-02").preflight() @pytest.mark.check_available_dates @@ -55,10 +56,8 @@ def test_error_on_all_missing(): """ Date subsetter should error when all dates are missing. """ - with pytest.raises(MissingDateError): - EventTableSubset(start="2016-05-01", stop="2016-05-02") - with pytest.raises(MissingDateError): - EventTableSubset(start="2016-05-01", stop="2016-05-02", table="events.topups") + with pytest.raises(PreFlightFailedException) as exc: + EventTableSubset(start="2016-05-01", stop="2016-05-02").preflight() def test_handles_mins(get_dataframe): diff --git a/flowmachine/tests/test_events_table_union.py b/flowmachine/tests/test_events_table_union.py index 2680a12279..3296839a4c 100644 --- a/flowmachine/tests/test_events_table_union.py +++ b/flowmachine/tests/test_events_table_union.py @@ -8,12 +8,12 @@ @pytest.mark.parametrize( - "columns", [["msisdn"], ["*"], ["id", "msisdn"]], ids=lambda x: f"{x}" + "columns", [["msisdn"], ["id", "msisdn"]], ids=lambda x: f"{x}" ) def test_events_tables_union_column_names(columns): """Test that EventsTableUnion column_names property is accurate.""" etu = EventsTablesUnion( - "2016-01-01", "2016-01-02", columns=columns, tables=["events.calls"] + "2016-01-01", "2016-01-02", columns=columns, tables=["calls"] ) assert etu.head(0).columns.tolist() == etu.column_names @@ -25,7 +25,7 @@ def test_events_table_union_subscriber_ident_substitutions(ident): "2016-01-01", "2016-01-02", columns=[ident], - tables=["events.calls"], + tables=["calls"], subscriber_identifier=ident, ) assert "subscriber" == etu.head(0).columns[0] @@ -57,7 +57,7 @@ def test_get_only_sms(get_length): "2016-01-01", "2016-01-02", columns=["msisdn", "msisdn_counterpart", "datetime"], - tables="events.sms", + tables="sms", ) assert get_length(etu) == 1246 @@ -94,6 +94,6 @@ def test_get_list_of_tables(get_length): "2016-01-01", "2016-01-02", columns=["msisdn", "msisdn_counterpart", "datetime"], - tables=["events.calls", "events.sms"], + tables=["calls", "sms"], ) assert get_length(etu) == 
2500 diff --git a/flowmachine/tests/test_geotable.py b/flowmachine/tests/test_geotable.py index a5627feb19..ec981d55a2 100644 --- a/flowmachine/tests/test_geotable.py +++ b/flowmachine/tests/test_geotable.py @@ -6,14 +6,16 @@ def test_geotable_bad_params(): """Test that geotable raises errors correctly.""" with pytest.raises(ValueError): - t = GeoTable("geography.admin3", geom_column="bad_column") + t = GeoTable("geography.admin3", geom_column="bad_column", columns=["geom"]) with pytest.raises(ValueError): - t = GeoTable("geography.admin3", gid_column="bad_column") + t = GeoTable( + "geography.admin3", gid_column="bad_column", columns=["geom", "gid"] + ) def test_geotable(): """Test that geotable will work with an obviously geographic table.""" - t = GeoTable("geography.admin3") + t = GeoTable("geography.admin3", columns=["geom", "admin3pcod", "admin0name"]) feature = t.to_geojson()["features"][0] assert feature["properties"]["admin0name"] == "Nepal" @@ -36,8 +38,7 @@ def test_geotable_uses_supplied_gid(): assert feature["id"] == "Sindhupalchok" -@pytest.mark.parametrize("columns", [None, ["gid", "geom"]]) -def test_geotable_column_names(columns): +def test_geotable_column_names(): """Test that column_names property matches head(0) for geotables""" - t = GeoTable("geography.admin3", columns=columns) + t = GeoTable("geography.admin3", columns=["gid", "geom"]) assert t.head(0).columns.tolist() == t.column_names diff --git a/flowmachine/tests/test_join_to_location.py b/flowmachine/tests/test_join_to_location.py index d5b905f96f..f162165dbb 100644 --- a/flowmachine/tests/test_join_to_location.py +++ b/flowmachine/tests/test_join_to_location.py @@ -9,7 +9,12 @@ import numpy as np from flowmachine.features import SubscriberLocations -from flowmachine.core import JoinToLocation, location_joined_query, make_spatial_unit +from flowmachine.core import ( + JoinToLocation, + location_joined_query, + make_spatial_unit, + Table, +) from flowmachine.core.errors import InvalidSpatialUnitError @@ -124,7 +129,7 @@ def test_join_with_polygon(get_dataframe, get_length): spatial_unit=make_spatial_unit( "polygon", region_id_column_name="admin3pcod", - geom_table="geography.admin3", + geom_table=Table("geography.admin3", columns=["admin3pcod", "geom"]), geom_column="geom", ), ) diff --git a/flowmachine/tests/test_query.py b/flowmachine/tests/test_query.py index 523b824262..8b7d254b9f 100644 --- a/flowmachine/tests/test_query.py +++ b/flowmachine/tests/test_query.py @@ -81,11 +81,11 @@ def test_is_stored(): class storable_query(Query): def _make_query(self): - return """SELECT 1""" + return """SELECT 1 as col""" @property def column_names(self) -> List[str]: - return ["1"] + return ["col"] sq = storable_query() sq.invalidate_db_cache() diff --git a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt index d4649bde87..44848428d5 100644 --- a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt +++ b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt @@ -1,5 +1,5 @@ { - "261b85fc9c64253889174ea8151f1d69": { + "1cc95af30e8b456123dab6ae4cc1f451": { "query_kind": "spatial_aggregate", "locations": { "query_kind": "daily_location", @@ -20,7 +20,7 @@ } } }, - "e142ffb3174add433422c2724a08c02b": { + "7b71413efc91213e798ca3bd53107186": { "query_kind": "spatial_aggregate", "locations": { "query_kind": "daily_location", @@ -32,7 +32,7 @@ "sampling": 
null } }, - "93580818b4cf2b30071c3c46c09e9de1": { + "0b3536e28692e84e84ba98c64df38280": { "query_kind": "location_event_counts", "start_date": "2016-01-01", "end_date": "2016-01-02", @@ -42,7 +42,7 @@ "event_types": null, "subscriber_subset": null }, - "b48d02838163766771eeed8cd8aafd98": { + "2b9d575accc718b8fba71457ea8acabf": { "query_kind": "spatial_aggregate", "locations": { "query_kind": "modal_location", @@ -64,11 +64,11 @@ ] } }, - "6521353e7563ed700dfd2cf90721934b": { + "ecff278a078630433a0f54f45ab0bcdc": { "query_kind": "geography", "aggregation_unit": "admin3" }, - "75b0532a747b473f69b45b78fdc29865": { + "05f6769184c4ad4571ec111b1a7b3029": { "query_kind": "meaningful_locations_aggregate", "aggregation_unit": "admin1", "start_date": "2016-01-01", @@ -161,7 +161,7 @@ "tower_cluster_call_threshold": 0, "subscriber_subset": null }, - "5df56eee4dc96ed961f8eb76583cc6de": { + "d7192263d6ac8f3826ad3d9edd92a255": { "query_kind": "meaningful_locations_between_label_od_matrix", "aggregation_unit": "admin1", "start_date": "2016-01-01", @@ -256,7 +256,7 @@ "event_types": null, "subscriber_subset": null }, - "00bee92b22b1ceb98244c5700a079656": { + "269b9d1f4ec72727ff2f9ced0027d356": { "query_kind": "meaningful_locations_between_dates_od_matrix", "aggregation_unit": "admin1", "start_date_a": "2016-01-01", @@ -355,7 +355,7 @@ ], "subscriber_subset": null }, - "c3dc95da89aeb0908b9398918dab6f26": { + "6f7a3eb1b622ae11f4fe8b72ed567068": { "query_kind": "flows", "from_location": { "query_kind": "daily_location", @@ -371,7 +371,7 @@ }, "join_type": "left outer" }, - "95fbc18554e15733df47bfd5cbaa3f87": { + "3b0cfae8f25961e166d5659130caa2fb": { "query_kind": "flows", "from_location": { "query_kind": "majority_location", @@ -418,7 +418,7 @@ } } }, - "920dfdf5568d75921c4173da8bccc6ef": { + "4e93b642dc8656bc73e116d17f362ce9": { "query_kind": "labelled_spatial_aggregate", "locations": { "query_kind": "coalesced_location", @@ -640,7 +640,7 @@ "stay_length_threshold": 2 } }, - "c915d87b66df83904634990e2f78da9e": { + "a18bd5d93c83c7da665815a8b255f36c": { "query_kind": "labelled_flows", "from_location": { "query_kind": "coalesced_location", diff --git a/flowmachine/tests/test_query_union.py b/flowmachine/tests/test_query_union.py index f86d1fbe30..04332e386f 100644 --- a/flowmachine/tests/test_query_union.py +++ b/flowmachine/tests/test_query_union.py @@ -10,7 +10,9 @@ def test_union_column_names(): """Test that Union's column_names property is accurate""" - union = Table("events.calls_20160101").union(Table("events.calls_20160102")) + union = Table("events.calls_20160101", columns=["id"]).union( + Table("events.calls_20160102", columns=["id"]) + ) assert union.head(0).columns.tolist() == union.column_names @@ -18,7 +20,7 @@ def test_union_all(get_dataframe): """ Test default union behaviour keeps duplicates. """ - q1 = Table(schema="events", name="calls") + q1 = Table(schema="events", name="calls", columns=["id"]) union_all = q1.union(q1) union_all_df = get_dataframe(union_all) single_id = union_all_df[union_all_df.id == "5wNJA-PdRJ4-jxEdG-yOXpZ"] @@ -29,7 +31,7 @@ def test_union(get_dataframe): """ Test union with all set to false dedupes. 
""" - q1 = Table(schema="events", name="calls") + q1 = Table(schema="events", name="calls", columns=["id", "msisdn"]) union = q1.union(q1, all=False) union_df = get_dataframe(union) single_id = union_df[union_df.id == "5wNJA-PdRJ4-jxEdG-yOXpZ"] @@ -42,7 +44,7 @@ def test_union_raises_with_mismatched_columns(): """ with pytest.raises(ValueError): Table(schema="events", name="calls", columns=["msisdn"]).union( - Table(schema="events", name="calls") + Table(schema="events", name="calls", columns=["id"]) ) diff --git a/flowmachine/tests/test_random.py b/flowmachine/tests/test_random.py index 3903a1d0da..56fbdd159b 100644 --- a/flowmachine/tests/test_random.py +++ b/flowmachine/tests/test_random.py @@ -11,6 +11,7 @@ import pytest import pickle +from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException from flowmachine.core.mixins import GraphMixin from flowmachine.features import daily_location, Flows from flowmachine.features.utilities.sets import UniqueSubscribers @@ -233,9 +234,11 @@ def test_system_rows_fail_with_inheritance(): """ Test whether the system row method fails if the subscriber queries for random rows on a parent table. """ - with pytest.raises(ValueError): - df = Table(name="events.calls").random_sample( - size=8, sampling_method="system_rows" + with pytest.raises(PreFlightFailedException): + df = ( + Table(name="events.calls", columns=["msisdn"]) + .random_sample(size=8, sampling_method="system_rows") + .preflight() ) @@ -290,7 +293,7 @@ def test_pickling(): ss1 = UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( size=10, sampling_method="system_rows" ) - ss2 = Table("events.calls").random_sample( + ss2 = Table("events.calls", columns=["id"]).random_sample( size=10, sampling_method="bernoulli", seed=0.73 ) for ss in [ss1, ss2]: diff --git a/flowmachine/tests/test_raster_statistics.py b/flowmachine/tests/test_raster_statistics.py index ce52185e06..058583633d 100644 --- a/flowmachine/tests/test_raster_statistics.py +++ b/flowmachine/tests/test_raster_statistics.py @@ -17,7 +17,7 @@ def test_computes_expected_clipping_values(get_dataframe): RasterStatistics() returns correct values when clipping vector and raster layers. """ G = "admin2pcod" - vector = Table(schema="geography", name="admin2") + vector = Table(schema="geography", name="admin2", columns=["admin2pcod", "geom"]) r = RasterStatistics( raster="population.small_nepal_raster", vector=vector, grouping_element=G ) @@ -43,7 +43,7 @@ def test_raises_notimplemented_when_wrong_statistic_requested(): RasterStatistics() raises NotImplementedError if wrong statistic requested. """ G = "admin2pcod" - vector = Table(schema="geography", name="admin2") + vector = Table(schema="geography", name="admin2", columns=["admin2pcod", "geom"]) with pytest.raises(NotImplementedError): r = RasterStatistics( raster="population.small_nepal_raster", @@ -58,7 +58,7 @@ def test_raises_valueerror_when_grouping_element_not_provided(): RasterStatistics() raises ValueError when `grouping_element` not provided. 
""" G = None - vector = Table(schema="geography", name="admin2") + vector = Table(schema="geography", name="admin2", columns=["admin2pcod", "geom"]) with pytest.raises(ValueError): r = RasterStatistics( "population.small_nepal_raster", vector=vector, grouping_element=None @@ -79,7 +79,7 @@ def test_raster_statistics_column_names_vector(get_dataframe): Test that column_names property matches head(0) for RasterStatistics when vector is not None """ - vector = Table(schema="geography", name="admin2") + vector = Table(schema="geography", name="admin2", columns=["admin2pcod", "geom"]) r = RasterStatistics( raster="population.small_nepal_raster", vector=vector, diff --git a/flowmachine/tests/test_redacted_total_events.py b/flowmachine/tests/test_redacted_total_events.py index 6e92f6a58b..aaf2a0b189 100644 --- a/flowmachine/tests/test_redacted_total_events.py +++ b/flowmachine/tests/test_redacted_total_events.py @@ -23,13 +23,13 @@ def test_all_above_threshold(get_dataframe): "2016-01-02", spatial_unit=make_spatial_unit("cell"), interval="day", - table=["events.calls"], + table=["calls"], ) ) us = get_dataframe( RedactedUniqueSubscriberCounts( unique_subscriber_counts=UniqueSubscriberCounts( - "2016-01-01", "2016-01-02", table=["events.calls"] + "2016-01-01", "2016-01-02", table=["calls"] ) ) ) @@ -49,7 +49,7 @@ def test_all_above_threshold_hour_bucket(get_dataframe): "2016-01-02", spatial_unit=make_spatial_unit("cell"), interval="hour", - table=["events.calls"], + table=["calls"], ) ) @@ -67,7 +67,7 @@ def test_all_above_threshold_minute_bucket(get_dataframe): "2016-01-01 13:00", spatial_unit=make_spatial_unit("cell"), interval="min", - table=["events.calls"], + table=["calls"], ) ) diff --git a/flowmachine/tests/test_redacted_unique_subscriber_counts.py b/flowmachine/tests/test_redacted_unique_subscriber_counts.py index 02be24bfbb..368c9f0bab 100644 --- a/flowmachine/tests/test_redacted_unique_subscriber_counts.py +++ b/flowmachine/tests/test_redacted_unique_subscriber_counts.py @@ -11,7 +11,7 @@ def test_all_above_threshold(get_dataframe): """ Test that all values in the redacted query are above the redaction threshold. 
""" - us = UniqueSubscriberCounts("2016-01-01", "2016-01-02", table=["events.calls"]) + us = UniqueSubscriberCounts("2016-01-01", "2016-01-02", table=["calls"]) rus_df = get_dataframe(RedactedUniqueSubscriberCounts(unique_subscriber_counts=us)) us_df = get_dataframe(us) assert all(rus_df.value > 15) diff --git a/flowmachine/tests/test_spatial_unit.py b/flowmachine/tests/test_spatial_unit.py index a33caadf23..2b53f52bc6 100644 --- a/flowmachine/tests/test_spatial_unit.py +++ b/flowmachine/tests/test_spatial_unit.py @@ -4,6 +4,7 @@ from flowmachine.core import CustomQuery from flowmachine.core.errors import InvalidSpatialUnitError +from flowmachine.core.infrastructure_table import SitesTable from flowmachine.core.spatial_unit import * import pytest @@ -47,7 +48,7 @@ def test_get_geom_query_column_names( { "spatial_unit_type": "polygon", "region_id_column_name": "id", - "geom_table": "infrastructure.sites", + "geom_table": SitesTable(), "geom_column": "geom_point", }, ["id"], @@ -56,7 +57,7 @@ def test_get_geom_query_column_names( { "spatial_unit_type": "polygon", "region_id_column_name": ["id"], - "geom_table": "infrastructure.sites", + "geom_table": SitesTable(), "geom_column": "geom_point", }, ["id"], @@ -97,7 +98,7 @@ def test_polygon_spatial_unit_column_list(): passed_cols = ["id"] psu = PolygonSpatialUnit( geom_table_column_names=passed_cols, - geom_table="infrastructure.sites", + geom_table=SitesTable(), geom_column="geom_point", ) loc_cols = psu.location_id_columns @@ -127,7 +128,9 @@ def test_missing_location_columns_raises_error(): { "spatial_unit_type": "polygon", "region_id_column_name": "id", - "geom_table": "infrastructure.sites", + "geom_table": Table( + "infrastructure.sites", columns=["id", "geom_point"] + ), "geom_column": "geom_point", }, "polygon", @@ -158,12 +161,12 @@ def test_canonical_names(make_spatial_unit_args, expected_name): { "spatial_unit_type": "polygon", "region_id_column_name": "admin3pcod", - "geom_table": "geography.admin3", + "geom_table": Table("geography.admin3", columns=["admin3pcod", "geom"]), }, { "spatial_unit_type": "polygon", "region_id_column_name": "id", - "geom_table": "infrastructure.sites", + "geom_table": Table("infrastructure.sites", columns=["id", "geom_point"]), "geom_column": "geom_point", }, ], diff --git a/flowmachine/tests/test_subscriber_call_durations.py b/flowmachine/tests/test_subscriber_call_durations.py index 29c0f149de..1c290cc76f 100644 --- a/flowmachine/tests/test_subscriber_call_durations.py +++ b/flowmachine/tests/test_subscriber_call_durations.py @@ -4,6 +4,7 @@ import pytest +from flowmachine.core import make_spatial_unit, Table from flowmachine.features.subscriber.subscriber_call_durations import * from flowmachine.core.statistic_types import Statistic @@ -48,7 +49,9 @@ def test_polygon_tables(get_dataframe): "2016-01-01", "2016-01-07", spatial_unit=make_spatial_unit( - "polygon", geom_table="geography.admin3", region_id_column_name="admin3name" + "polygon", + geom_table=Table("geography.admin3", columns=["geom", "admin3name"]), + region_id_column_name="admin3name", ), ) df = get_dataframe(per_location_durations) @@ -67,7 +70,9 @@ def test_polygon_tables(get_dataframe): "2016-01-01", "2016-01-07", spatial_unit=make_spatial_unit( - "polygon", geom_table="geography.admin3", region_id_column_name="admin3name" + "polygon", + geom_table=Table("geography.admin3", columns=["geom", "admin3name"]), + region_id_column_name="admin3name", ), ) diff --git a/flowmachine/tests/test_subscriber_degree.py 
b/flowmachine/tests/test_subscriber_degree.py index 0f83858941..d9e8b9f71b 100644 --- a/flowmachine/tests/test_subscriber_degree.py +++ b/flowmachine/tests/test_subscriber_degree.py @@ -25,12 +25,8 @@ def test_returns_correct_values(get_dataframe): """ # We expect subscriber '2Dq97XmPqvL6noGk' to have a single event in df1 # and two events in df2 (due to the larger time interval). - ud1 = SubscriberDegree( - "2016-01-01 12:35:00", "2016-01-01 12:40:00", tables="events.sms" - ) - ud2 = SubscriberDegree( - "2016-01-01 12:28:00", "2016-01-01 12:40:00", tables="events.sms" - ) + ud1 = SubscriberDegree("2016-01-01 12:35:00", "2016-01-01 12:40:00", tables="sms") + ud2 = SubscriberDegree("2016-01-01 12:28:00", "2016-01-01 12:40:00", tables="sms") df1 = get_dataframe(ud1).set_index("subscriber") df2 = get_dataframe(ud2).set_index("subscriber") @@ -49,13 +45,13 @@ def test_returns_correct_in_out_values(get_dataframe): ud1 = SubscriberDegree( "2016-01-01 12:35:00", "2016-01-01 12:40:00", - tables="events.sms", + tables="sms", direction="in", ) ud2 = SubscriberDegree( "2016-01-01 12:28:00", "2016-01-01 12:40:00", - tables="events.sms", + tables="sms", direction="out", ) diff --git a/flowmachine/tests/test_subscriber_event_count.py b/flowmachine/tests/test_subscriber_event_count.py index d0d14acc4b..dd4a521a1d 100644 --- a/flowmachine/tests/test_subscriber_event_count.py +++ b/flowmachine/tests/test_subscriber_event_count.py @@ -19,20 +19,16 @@ def test_event_count(get_dataframe): "2016-01-01", "2016-01-08", direction="both", - tables=["events.calls", "events.sms", "events.mds", "events.topups"], + tables=["calls", "sms", "mds", "topups"], ) df = get_dataframe(query).set_index("subscriber") assert df.loc["DzpZJ2EaVQo2X5vM"].value == 69 - query = EventCount( - "2016-01-01", "2016-01-08", direction="both", tables=["events.mds"] - ) + query = EventCount("2016-01-01", "2016-01-08", direction="both", tables=["mds"]) df = get_dataframe(query).set_index("subscriber") assert df.loc["E0LZAa7AyNd34Djq"].value == 8 - query = EventCount( - "2016-01-01", "2016-01-08", direction="both", tables="events.mds" - ) + query = EventCount("2016-01-01", "2016-01-08", direction="both", tables="mds") df = get_dataframe(query).set_index("subscriber") assert df.loc["E0LZAa7AyNd34Djq"].value == 8 @@ -80,6 +76,4 @@ def test_directed_count_undirected_tables_raises(): Test that requesting directed counts of undirected tables raises warning and errors. 
""" with pytest.raises(ValueError): - query = EventCount( - "2016-01-01", "2016-01-08", direction="out", tables=["events.mds"] - ) + query = EventCount("2016-01-01", "2016-01-08", direction="out", tables=["mds"]) diff --git a/flowmachine/tests/test_subscriber_event_type_proportion.py b/flowmachine/tests/test_subscriber_event_type_proportion.py index d14e139cce..50f10afedf 100644 --- a/flowmachine/tests/test_subscriber_event_type_proportion.py +++ b/flowmachine/tests/test_subscriber_event_type_proportion.py @@ -10,12 +10,12 @@ @pytest.mark.parametrize( "numerator, numerator_direction, msisdn, want", [ - ("events.calls", "both", "AgB6KR3Levd9Z1vJ", 0.351_852), - ("events.calls", "in", "AgB6KR3Levd9Z1vJ", 0.203_703_703_7), - ("events.sms", "both", "7ra3xZakjEqB1Al5", 0.362_069), - ("events.mds", "both", "QrAlXqDbXDkNJe3E", 0.236_363_63), - ("events.topups", "both", "bKZLwjrMQG7z468y", 0.183_098_5), - (["events.calls", "events.sms"], "both", "AgB6KR3Levd9Z1vJ", 0.648_148_1), + ("calls", "both", "AgB6KR3Levd9Z1vJ", 0.351_852), + ("calls", "in", "AgB6KR3Levd9Z1vJ", 0.203_703_703_7), + ("sms", "both", "7ra3xZakjEqB1Al5", 0.362_069), + ("mds", "both", "QrAlXqDbXDkNJe3E", 0.236_363_63), + ("topups", "both", "bKZLwjrMQG7z468y", 0.183_098_5), + (["calls", "sms"], "both", "AgB6KR3Levd9Z1vJ", 0.648_148_1), ], ) def test_proportion_event_type( @@ -30,11 +30,11 @@ def test_proportion_event_type( numerator, numerator_direction=numerator_direction, tables=[ - "events.calls", - "events.sms", - "events.mds", - "events.topups", - "events.forwards", + "calls", + "sms", + "mds", + "topups", + "forwards", ], ) df = get_dataframe(query).set_index("subscriber") diff --git a/flowmachine/tests/test_subscriber_location_cluster.py b/flowmachine/tests/test_subscriber_location_cluster.py index 8df3e105df..62f7ce249a 100644 --- a/flowmachine/tests/test_subscriber_location_cluster.py +++ b/flowmachine/tests/test_subscriber_location_cluster.py @@ -182,7 +182,10 @@ def test_different_call_days_format(get_dataframe): cd.store().result() har = get_dataframe( - HartiganCluster(calldays=Table(cd.fully_qualified_table_name), radius=50) + HartiganCluster( + calldays=Table(cd.fully_qualified_table_name, columns=cd.column_names), + radius=50, + ) ) assert isinstance(har, pd.DataFrame) diff --git a/flowmachine/tests/test_subscriber_location_subset.py b/flowmachine/tests/test_subscriber_location_subset.py index 313ceb6379..5b759b5bf7 100644 --- a/flowmachine/tests/test_subscriber_location_subset.py +++ b/flowmachine/tests/test_subscriber_location_subset.py @@ -28,7 +28,7 @@ def test_subscribers_make_atleast_one_call_in_admin0(): sls = SubscriberLocationSubset( start, stop, min_calls=1, spatial_unit=make_spatial_unit("admin", level=0) ) - us = UniqueSubscribers(start, stop, table="events.calls") + us = UniqueSubscribers(start, stop, table="calls") sls_subs = set(sls.get_dataframe()["subscriber"]) us_subs = set(us.get_dataframe()["subscriber"]) @@ -43,9 +43,9 @@ def test_subscribers_who_make_atleast_3_calls_in_central_development_region(): within Central Development admin1 region. 
""" start, stop = "2016-01-01", "2016-01-07" - regions = Table("admin2", "geography").subset( - "admin1name", ["Central Development Region"] - ) + regions = Table( + "admin2", schema="geography", columns=["admin1name", "admin2pcod", "geom"] + ).subset("admin1name", ["Central Development Region"]) sls = SubscriberLocationSubset( start, diff --git a/flowmachine/tests/test_subscriber_locations.py b/flowmachine/tests/test_subscriber_locations.py index 8f19dba31c..e1b151e2cb 100644 --- a/flowmachine/tests/test_subscriber_locations.py +++ b/flowmachine/tests/test_subscriber_locations.py @@ -4,7 +4,7 @@ import pytest -from flowmachine.core import make_spatial_unit +from flowmachine.core import make_spatial_unit, Table from flowmachine.features.utilities.subscriber_locations import SubscriberLocations pytestmark = pytest.mark.usefixtures("skip_datecheck") @@ -31,7 +31,9 @@ def test_can_get_pcods(get_dataframe): "2016-01-01 13:30:30", "2016-01-02 16:25:00", spatial_unit=make_spatial_unit( - "polygon", region_id_column_name="admin3pcod", geom_table="geography.admin3" + "polygon", + region_id_column_name="admin3pcod", + geom_table=Table("geography.admin3", columns=["admin3pcod", "geom"]), ), ) df = get_dataframe(subscriber_pcod) diff --git a/flowmachine/tests/test_subscriber_subsetting.py b/flowmachine/tests/test_subscriber_subsetting.py index 85eba01566..89a44bcd87 100644 --- a/flowmachine/tests/test_subscriber_subsetting.py +++ b/flowmachine/tests/test_subscriber_subsetting.py @@ -22,12 +22,12 @@ @pytest.mark.parametrize( - "columns", [["msisdn"], ["*"], ["id", "msisdn"]], ids=lambda x: f"{x}" + "columns", [["msisdn"], None, ["id", "msisdn"]], ids=lambda x: f"{x}" ) def test_events_table_subset_column_names(columns): """Test that EventTableSubset column_names property is accurate.""" etu = EventTableSubset( - start="2016-01-01", stop="2016-01-02", columns=columns, table="events.calls" + start="2016-01-01", stop="2016-01-02", columns=columns, table="calls" ) assert etu.head(0).columns.tolist() == etu.column_names @@ -39,7 +39,7 @@ def test_events_table_subscriber_ident_substitutions(ident): start="2016-01-01", stop="2016-01-02", columns=[ident], - table="events.calls", + table="calls", subscriber_identifier=ident, ) assert "subscriber" == etu.head(0).columns[0] @@ -73,9 +73,12 @@ def subscriber_list_table(subscriber_list, flowmachine_connect): formatted_subscribers ) trans.exec_driver_sql(sql) - subs_table = Table("subscriber_list") - yield subs_table - subs_table.invalidate_db_cache(drop=True) + subs_table = Table("subscriber_list", columns=["subscriber"]) + subs_table.preflight() + try: + yield subs_table + finally: + subs_table.invalidate_db_cache(drop=True) def test_cdrs_can_be_subset_by_table( diff --git a/flowmachine/tests/test_subsetting.py b/flowmachine/tests/test_subsetting.py index 24e1d1dae3..df581dbb92 100644 --- a/flowmachine/tests/test_subsetting.py +++ b/flowmachine/tests/test_subsetting.py @@ -177,7 +177,7 @@ def test_subset_subset(get_dataframe): sub_vala = "Central Development Region" sub_colb = "admin2name" sub_valb = "Bagmati" - t = Table("geography.admin3") + t = Table("geography.admin3", columns=[sub_cola, sub_colb]) t_df = get_dataframe(t) sub_q = t.subset(sub_cola, sub_vala).subset(sub_colb, sub_valb) @@ -199,7 +199,7 @@ def test_subset_subsetnumeric(get_dataframe): sub_colb = "shape_area" sub_lowb = 0.1 sub_highb = 0.12 - t = Table("geography.admin3") + t = Table("geography.admin3", columns=[sub_cola, sub_colb]) t_df = get_dataframe(t) sub_q1 = t.subset(sub_cola, 
sub_vala).numeric_subset(sub_colb, sub_lowb, sub_highb) @@ -227,7 +227,7 @@ def test_subsetnumeric_subsetnumeric(get_dataframe): sub_colb = "shape_leng" sub_lowb = 1.0 sub_highb = 2.0 - t = Table("geography.admin3") + t = Table("geography.admin3", columns=[sub_cola, sub_colb]) t_df = get_dataframe(t) sub_q = t.numeric_subset(sub_cola, sub_lowa, sub_lowb).numeric_subset( diff --git a/flowmachine/tests/test_table.py b/flowmachine/tests/test_table.py index 51955abf26..8f8396e7bc 100644 --- a/flowmachine/tests/test_table.py +++ b/flowmachine/tests/test_table.py @@ -3,34 +3,52 @@ import pytest from flowmachine.core import Table +from flowmachine.core.errors.flowmachine_errors import ( + QueryErroredException, + PreFlightFailedException, +) -@pytest.mark.parametrize("columns", [None, ["msisdn", "id"]]) +@pytest.mark.parametrize("columns", [["msisdn", "id"]]) def test_table_column_names(columns): """Test that column_names property matches head(0) for tables""" t = Table("events.calls", columns=columns) assert t.head(0).columns.tolist() == t.column_names -def test_table_init(): +@pytest.mark.parametrize( + "args", + [ + dict(name="events.calls", schema="extra_schema", columns=["id"]), + dict(name="calls", schema="events", columns=None), + dict(name="calls", schema="events", columns=[]), + ], +) +def test_table_init(args): """ Test that table creation handles params properly. """ - t = Table("events.calls") with pytest.raises(ValueError): - Table("events.calls", "moose") - with pytest.raises(ValueError): - Table("events.calls", columns="NO SUCH COLUMN") - with pytest.raises(ValueError): - Table("NOSUCHTABLE") - with pytest.raises(ValueError): - Table("events.WHAAAAAAAAT") + Table(**args) + +@pytest.mark.parametrize( + "args", + [ + dict(name="events.calls", columns=["NO SUCH COLUMN"]), + dict(name="NO SUCH TABLE", columns=["id"]), + ], +) +def test_table_preflight(args): + with pytest.raises(PreFlightFailedException): + Table(**args).preflight() -def public_schema_checked(): - """Test that where no schema is provided, public schema is checked.""" - t = Table("gambia_admin2") + +def test_public_schema_checked(): + """Test that where no schema is provided, user schema is checked.""" + t = Table("gambia_admin2", columns=["geom"]) + assert "flowmachine" == t.schema def test_children(): @@ -38,8 +56,8 @@ def test_children(): Test that table inheritance is correctly detected. """ - assert Table("events.calls").has_children() - assert not Table("geography.admin3").has_children() + assert Table("events.calls", columns=["id"]).has_children() + assert not Table("geography.admin3", columns=["geom"]).has_children() def test_columns(): @@ -54,7 +72,7 @@ def test_store_with_table(): """ Test that a subset of a table can be stored. """ - t = Table("events.calls") + t = Table("events.calls", columns=["id"]) s = t.subset("id", ["5wNJA-PdRJ4-jxEdG-yOXpZ", "5wNJA-PdRJ4-jxEdG-yOXpZ"]) s.store().result() assert s.is_stored @@ -73,23 +91,17 @@ def test_get_table_is_self(): def test_dependencies(): """ - Check that a table without explicit columns has no other queries as a dependency, - and a table with explicit columns has its parent table as a dependency. + Check that a table has no other queries as a dependency. 
""" - t1 = Table("events.calls") + t1 = Table("events.calls", columns=["id"]) assert t1.dependencies == set() - t2 = Table("events.calls", columns=["id"]) - assert len(t2.dependencies) == 1 - t2_parent = t2.dependencies.pop() - assert "057addedac04dbeb1dcbbb6b524b43f0" == t2_parent.query_id - def test_subset(): """ Test that a subset of a table doesn't show as stored. """ - ss = Table("events.calls").subset( + ss = Table("events.calls", columns=["id"]).subset( "id", ["5wNJA-PdRJ4-jxEdG-yOXpZ", "5wNJA-PdRJ4-jxEdG-yOXpZ"] ) assert not ss.is_stored @@ -99,7 +111,7 @@ def test_pickling(): """ Test that we can pickle and unpickle subset classes. """ - ss = Table("events.calls").subset( + ss = Table("events.calls", columns=["id"]).subset( "id", ["5wNJA-PdRJ4-jxEdG-yOXpZ", "5wNJA-PdRJ4-jxEdG-yOXpZ"] ) assert ss.get_query() == pickle.loads(pickle.dumps(ss)).get_query() diff --git a/flowmachine/tests/test_to_sql.py b/flowmachine/tests/test_to_sql.py index 5599789232..fea0fae1fb 100644 --- a/flowmachine/tests/test_to_sql.py +++ b/flowmachine/tests/test_to_sql.py @@ -46,7 +46,7 @@ def test_can_force_rewrite(flowmachine_connect, get_length): sql = """DELETE FROM tests.test_rewrite""" with get_db().engine.begin() as conn: conn.exec_driver_sql(sql) - assert 0 == get_length(Table("tests.test_rewrite")) + assert 0 == get_length(Table("tests.test_rewrite", columns=query.column_names)) query.invalidate_db_cache(name="test_rewrite", schema="tests") query.to_sql(name="test_rewrite", schema="tests").result() - assert 1 < get_length(Table("tests.test_rewrite")) + assert 1 < get_length(Table("tests.test_rewrite", columns=query.column_names)) diff --git a/flowmachine/tests/test_total_location_events.py b/flowmachine/tests/test_total_location_events.py index 264b2fc548..c671ad164e 100644 --- a/flowmachine/tests/test_total_location_events.py +++ b/flowmachine/tests/test_total_location_events.py @@ -56,7 +56,7 @@ def test_ignore_texts(get_dataframe): "2016-01-01", "2016-01-04", spatial_unit=make_spatial_unit("versioned-site"), - table="events.calls", + table="calls", ) df = get_dataframe(te) diff --git a/flowmachine/tests/test_total_network_objects.py b/flowmachine/tests/test_total_network_objects.py index 8f84d29c94..9931e837a0 100644 --- a/flowmachine/tests/test_total_network_objects.py +++ b/flowmachine/tests/test_total_network_objects.py @@ -135,6 +135,8 @@ def test_median_returns_correct_values(get_dataframe): """ instance = AggregateNetworkObjects( total_network_objects=TotalNetworkObjects( + start="2016-01-01", + stop="2016-01-08", table="calls", total_by="hour", network_object=make_spatial_unit("versioned-site"), diff --git a/flowmachine/tests/test_union_with_fixed_values.py b/flowmachine/tests/test_union_with_fixed_values.py index f5bdc1c3e7..0d651e3b99 100644 --- a/flowmachine/tests/test_union_with_fixed_values.py +++ b/flowmachine/tests/test_union_with_fixed_values.py @@ -13,19 +13,25 @@ def test_union_column_names(): """Test that Union's column_names property is accurate""" union = UnionWithFixedValues( - [Table("events.calls_20160101"), Table("events.calls_20160102")], + [ + Table("events.calls_20160101", columns=["msisdn"]), + Table("events.calls_20160102", columns=["msisdn"]), + ], ["extra_val", "extra_val"], fixed_value_column_name="extra_col", ) assert union.head(0).columns.tolist() == union.column_names - assert union.column_names == [*Table("events.calls_20160101").columns, "extra_col"] + assert union.column_names == [ + *Table("events.calls_20160101", columns=["msisdn"]).columns, + "extra_col", + 
] def test_union_all(get_dataframe): """ Test default union behaviour keeps duplicates. """ - q1 = Table(schema="events", name="calls") + q1 = Table(schema="events", name="calls", columns=["id"]) union_all = q1.union(q1) union_all_df = get_dataframe(union_all) single_id = union_all_df[union_all_df.id == "5wNJA-PdRJ4-jxEdG-yOXpZ"] @@ -37,7 +43,10 @@ def test_union(get_dataframe): Test union adds extra columns. """ union = UnionWithFixedValues( - [Table("events.calls_20160101"), Table("events.calls_20160101")], + [ + Table("events.calls_20160101", columns=["msisdn"]), + Table("events.calls_20160101", columns=["msisdn"]), + ], ["extra_val", "extra_val_1"], fixed_value_column_name="extra_col", ) @@ -51,7 +60,10 @@ def test_union_date_type(get_dataframe): Test union casts types correctly for datetimes. """ union = UnionWithFixedValues( - [Table("events.calls_20160101"), Table("events.calls_20160101")], + [ + Table("events.calls_20160101", columns=["msisdn"]), + Table("events.calls_20160101", columns=["msisdn"]), + ], [datetime.datetime(2016, 1, 1), datetime.datetime(2016, 1, 2)], fixed_value_column_name="extra_col", ) diff --git a/integration_tests/tests/flowmachine_server_tests/helpers.py b/integration_tests/tests/flowmachine_server_tests/helpers.py index 1bdd7574e6..3b15625866 100644 --- a/integration_tests/tests/flowmachine_server_tests/helpers.py +++ b/integration_tests/tests/flowmachine_server_tests/helpers.py @@ -6,6 +6,8 @@ from flowmachine.core.server.utils import send_zmq_message_and_receive_reply +internal_tables = sorted(["cache_config", "cached", "dependencies", "zero_cache"]) + def poll_until_done(port, query_id, max_tries=100): """ @@ -39,9 +41,8 @@ def get_cache_tables(fm_conn, exclude_internal_tables=True): insp = inspect(fm_conn.engine) cache_tables = insp.get_table_names(schema="cache") if exclude_internal_tables: - cache_tables.remove("cached") - cache_tables.remove("dependencies") - cache_tables.remove("cache_config") + for table in internal_tables: + cache_tables.remove(table) return sorted(cache_tables) @@ -57,7 +58,7 @@ def cache_schema_is_empty(fm_conn, check_internal_tables_are_empty=True): cache_tables = sorted(insp.get_table_names(schema="cache")) # Check that there are no cached tables except the flowdb-internal ones - if cache_tables != ["cache_config", "cached", "dependencies"]: + if cache_tables != internal_tables: return False if check_internal_tables_are_empty: diff --git a/integration_tests/tests/flowmachine_tests/test_events_tables_union.py b/integration_tests/tests/flowmachine_tests/test_events_tables_union.py index 79c02e628c..7137e2a6c6 100644 --- a/integration_tests/tests/flowmachine_tests/test_events_tables_union.py +++ b/integration_tests/tests/flowmachine_tests/test_events_tables_union.py @@ -15,7 +15,7 @@ def test_events_tables_union_1_sql(diff_reporter): etu = EventsTablesUnion( start="2016-01-02", stop="2016-01-03", - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration", @@ -38,7 +38,7 @@ def test_events_tables_union_1_df(diff_reporter, get_dataframe): etu = EventsTablesUnion( start="2016-01-02", stop="2016-01-03", - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration", @@ -62,7 +62,7 @@ def test_events_tables_union_2_sql(diff_reporter): start="2016-01-03", stop="2016-01-05", hours=(7, 13), - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration", @@ -86,7 +86,7 @@ def test_events_tables_union_2_df(diff_reporter, get_dataframe): start="2016-01-03", stop="2016-01-05", hours=(7, 
13), - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration", @@ -110,7 +110,7 @@ def test_events_tables_union_3_sql(diff_reporter): start="2016-01-02", stop="2016-01-04", hours=(21, 5), - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration", @@ -134,7 +134,7 @@ def test_events_tables_union_3_df(diff_reporter, get_dataframe): start="2016-01-02", stop="2016-01-04", hours=(21, 5), - tables=["events.calls"], + tables=["calls"], columns=[ "datetime", "duration",
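
The tests above exercise the new two-step lifecycle end to end: query objects construct without erroring even when the underlying data is absent, and database-dependent validation only happens when preflight() is called. A minimal sketch of that flow within a connected flowmachine session; the shape of the errors attribute (a mapping from each failing check to the exceptions it raised) is inferred from test_daily_location.py above:

    from flowmachine.core.errors import MissingDateError
    from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException
    from flowmachine.features import daily_location

    dl = daily_location("2016-01-31")  # constructs even though the date has no data
    try:
        dl.preflight()  # runs all applicable pre-flight checks for the query
    except PreFlightFailedException as exc:
        # Inferred from the tests: each failing check maps to a list of the
        # exceptions it raised, here MissingDateError for the absent date
        for failed_check, check_errors in exc.errors.items():
            assert isinstance(check_errors[0], MissingDateError)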
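
EventTableSubset follows the same pattern: the date-availability check that used to raise MissingDateError from __init__ now only fires at preflight time, as test_event_table_subset.py above shows. A short sketch:

    from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException
    from flowmachine.features.utilities.event_table_subset import EventTableSubset

    subset = EventTableSubset(start="2016-05-01", stop="2016-05-02")  # no longer raises here
    try:
        subset.preflight()  # the missing-date check runs here instead
    except PreFlightFailedException:
        pass  # previously this surfaced as a MissingDateError from __init__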
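
Table no longer infers its columns from the database, which is why every call site above now passes columns explicitly; existence of the table and its columns is verified at preflight time rather than in __init__. A sketch mirroring test_table.py:

    from flowmachine.core import Table
    from flowmachine.core.errors.flowmachine_errors import PreFlightFailedException

    t = Table("events.calls", columns=["id", "msisdn"])  # columns are now required
    t.preflight()  # table and column existence are checked here

    bad = Table("events.calls", columns=["NO SUCH COLUMN"])  # still constructs
    try:
        bad.preflight()
    except PreFlightFailedException:
        pass  # the invalid column is only detected at preflight time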
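
The observable behaviour of the new module-level _parse_tables helper, read straight from the diff above; note the interim ("calls", "sms") default that its in-code comment flags, and that list inputs come back deduplicated and sorted:

    from flowmachine.features.utilities.events_tables_union import _parse_tables

    assert _parse_tables(None) == ("calls", "sms")  # interim default noted in the diff
    assert _parse_tables("sms") == ["sms"]  # a bare string is wrapped in a list
    assert _parse_tables(["sms", "calls", "sms"]) == ["calls", "sms"]  # deduplicated, sorted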
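
Table arguments throughout the test suite drop the events. prefix, which is no longer required. A before/after sketch, assuming EventsTablesUnion is importable from flowmachine.features as elsewhere in the suite:

    from flowmachine.features import EventsTablesUnion

    etu = EventsTablesUnion(
        "2016-01-01",
        "2016-01-02",
        columns=["msisdn", "msisdn_counterpart", "datetime"],
        tables=["calls", "sms"],  # previously ["events.calls", "events.sms"]
    )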
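
Because Table now needs explicit columns, geom_table arguments are passed as Table objects carrying just the columns the spatial unit uses, rather than bare table-name strings, as in the updated fixtures and tests above:

    from flowmachine.core import Table, make_spatial_unit

    spatial_unit = make_spatial_unit(
        "polygon",
        region_id_column_name="admin3pcod",
        geom_table=Table("geography.admin3", columns=["admin3pcod", "geom"]),
    )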