diff --git a/ingestify/application/dataset_store.py b/ingestify/application/dataset_store.py
index 66b7d72..f7a5a68 100644
--- a/ingestify/application/dataset_store.py
+++ b/ingestify/application/dataset_store.py
@@ -66,6 +66,13 @@ def get_dataset_collection(
         if isinstance(selector, dict):
             # By-pass the build as we don't want to specify data_spec_versions here... (for now)
             selector = Selector(selector)
+        elif isinstance(selector, list):
+            if not selector:
+                return DatasetCollection()
+
+            if isinstance(selector[0], dict):
+                # Convert all selector dicts to Selectors
+                selector = [Selector(_) for _ in selector]
 
         dataset_collection = self.dataset_repository.get_dataset_collection(
             bucket=self.bucket,
@@ -248,7 +255,7 @@ def load_files(
         dataset: Dataset,
         data_feed_keys: Optional[List[str]] = None,
         lazy: bool = False,
-        auto_rewind: bool = True
+        auto_rewind: bool = True,
     ) -> FileCollection:
         current_revision = dataset.current_revision
         files = {}
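Note on the dataset_store.py change above: `get_dataset_collection` keeps accepting a single selector dict, but now also normalizes a list of dicts into a list of `Selector` objects, which the repository can resolve in one query. A minimal usage sketch, assuming an engine configured the way `ingestify/main.py` builds it; the attribute names `competition_id`/`season_id` are only illustrative:

```python
from ingestify.main import get_engine

engine = get_engine("config.yaml", "main")
store = engine.store

# Single selector dict, as before.
one = store.get_dataset_collection(
    dataset_type="match",
    selector={"competition_id": 1, "season_id": 2},
)

# New: several selectors at once; each dict is converted to a Selector and the
# repository resolves them in a single query.
many = store.get_dataset_collection(
    dataset_type="match",
    selector=[
        {"competition_id": 1, "season_id": 2},
        {"competition_id": 1, "season_id": 3},
    ],
)

# An empty list short-circuits to an empty DatasetCollection.
assert len(store.get_dataset_collection(dataset_type="match", selector=[])) == 0
```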
diff --git a/ingestify/application/loader.py b/ingestify/application/loader.py
index 10b58de..c3532b1 100644
--- a/ingestify/application/loader.py
+++ b/ingestify/application/loader.py
@@ -4,7 +4,7 @@
 from typing import List
 
 from ingestify.domain.models import Dataset, Identifier, Selector, Source, Task, TaskSet
-from ingestify.utils import map_in_pool
+from ingestify.utils import map_in_pool, TaskExecutor
 
 from .dataset_store import DatasetStore
 from ..domain.models.data_spec_version_collection import DataSpecVersionCollection
@@ -94,8 +94,6 @@ def add_extract_job(self, extract_job: ExtractJob):
         self.extract_jobs.append(extract_job)
 
     def collect_and_run(self):
-        task_set = TaskSet()
-
         total_dataset_count = 0
 
         # First collect all selectors, before discovering datasets
@@ -155,77 +153,96 @@ def collect_and_run(self):
             else:
                 selectors[key] = (extract_job, selector)
 
+        def run_task(task):
+            logger.info(f"Running task {task}")
+            task.run()
+
+        task_executor = TaskExecutor()
+
         for extract_job, selector in selectors.values():
             logger.debug(
                 f"Discovering datasets from {extract_job.source.__class__.__name__} using selector {selector}"
             )
-            dataset_identifiers = [
-                Identifier.create_from(selector, **identifier)
-                # We have to pass the data_spec_versions here as a Source can add some
-                # extra data to the identifier which is retrieved in a certain data format
-                for identifier in extract_job.source.discover_datasets(
-                    dataset_type=extract_job.dataset_type,
-                    data_spec_versions=selector.data_spec_versions,
-                    **selector.filtered_attributes,
-                )
-            ]
-
-            task_subset = TaskSet()
-            dataset_collection = self.store.get_dataset_collection(
+            dataset_collection_metadata = self.store.get_dataset_collection(
                 dataset_type=extract_job.dataset_type,
-                provider=extract_job.source.provider,
+                data_spec_versions=selector.data_spec_versions,
                 selector=selector,
+                metadata_only=True,
+            ).metadata
+
+            # There are two different, but similar flows here:
+            # 1. The discover_datasets returns a list, and the entire list can be processed at once
+            # 2. The discover_datasets returns an iterator of batches; in that case each batch is processed right away
+            discovered_datasets = extract_job.source.discover_datasets(
+                dataset_type=extract_job.dataset_type,
+                data_spec_versions=selector.data_spec_versions,
+                dataset_collection_metadata=dataset_collection_metadata,
+                **selector.filtered_attributes,
             )
-            skip_count = 0
-            total_dataset_count += len(dataset_identifiers)
-
-            for dataset_identifier in dataset_identifiers:
-                if dataset := dataset_collection.get(dataset_identifier):
-                    if extract_job.fetch_policy.should_refetch(
-                        dataset, dataset_identifier
-                    ):
-                        task_subset.add(
-                            UpdateDatasetTask(
-                                source=extract_job.source,
-                                dataset=dataset,  # Current dataset from the database
-                                dataset_identifier=dataset_identifier,  # Most recent dataset_identifier
-                                data_spec_versions=selector.data_spec_versions,
-                                store=self.store,
-                            )
-                        )
-                    else:
-                        skip_count += 1
-                else:
-                    if extract_job.fetch_policy.should_fetch(dataset_identifier):
-                        task_subset.add(
-                            CreateDatasetTask(
-                                source=extract_job.source,
-                                dataset_type=extract_job.dataset_type,
-                                dataset_identifier=dataset_identifier,
-                                data_spec_versions=selector.data_spec_versions,
-                                store=self.store,
-                            )
-                        )
-                    else:
-                        skip_count += 1
-
-            logger.info(
-                f"Discovered {len(dataset_identifiers)} datasets from {extract_job.source.__class__.__name__} "
-                f"using selector {selector} => {len(task_subset)} tasks. {skip_count} skipped."
-            )
-
-            task_set += task_subset
+            if isinstance(discovered_datasets, list):
+                batches = [discovered_datasets]
+            else:
+                batches = discovered_datasets
+
+            for batch in batches:
+                dataset_identifiers = [
+                    Identifier.create_from(selector, **identifier)
+                    # We have to pass the data_spec_versions here as a Source can add some
+                    # extra data to the identifier which is retrieved in a certain data format
+                    for identifier in batch
+                ]
+
+                # Load all available datasets based on the discovered dataset identifiers
+                dataset_collection = self.store.get_dataset_collection(
+                    dataset_type=extract_job.dataset_type,
+                    provider=extract_job.source.provider,
+                    selector=dataset_identifiers,
+                )
+
+                skip_count = 0
+                total_dataset_count += len(dataset_identifiers)
+
+                task_set = TaskSet()
+                for dataset_identifier in dataset_identifiers:
+                    if dataset := dataset_collection.get(dataset_identifier):
+                        if extract_job.fetch_policy.should_refetch(
+                            dataset, dataset_identifier
+                        ):
+                            task_set.add(
+                                UpdateDatasetTask(
+                                    source=extract_job.source,
+                                    dataset=dataset,  # Current dataset from the database
+                                    dataset_identifier=dataset_identifier,  # Most recent dataset_identifier
+                                    data_spec_versions=selector.data_spec_versions,
+                                    store=self.store,
+                                )
+                            )
+                        else:
+                            skip_count += 1
+                    else:
+                        if extract_job.fetch_policy.should_fetch(dataset_identifier):
+                            task_set.add(
+                                CreateDatasetTask(
+                                    source=extract_job.source,
+                                    dataset_type=extract_job.dataset_type,
+                                    dataset_identifier=dataset_identifier,
+                                    data_spec_versions=selector.data_spec_versions,
+                                    store=self.store,
+                                )
+                            )
+                        else:
+                            skip_count += 1
+
+                logger.info(
+                    f"Discovered {len(dataset_identifiers)} datasets from {extract_job.source.__class__.__name__} "
+                    f"using selector {selector} => {len(task_set)} tasks. {skip_count} skipped."
+                )
 
-        if len(task_set):
-            processes = cpu_count()
-            logger.info(f"Scheduled {len(task_set)} tasks. With {processes} processes")
+                task_executor.run(run_task, task_set)
+                logger.info(f"Scheduled {len(task_set)} tasks")
 
-            def run_task(task):
-                logger.info(f"Running task {task}")
-                task.run()
+        task_executor.join()
 
-            map_in_pool(run_task, task_set)
-        else:
-            logger.info("Nothing to do.")
+        logger.info("Done")
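Note on the loader changes above: a plain list from `discover_datasets` is treated as one batch, while an iterator is consumed batch by batch, with the tasks for each batch scheduled before the next batch is pulled. A rough sketch of a source that leans on this, modelled after the `BatchSource` used in the tests further down; the 25 fake match ids and the batch size of 10 are made up for illustration, only the method signatures follow the updated `Source` ABC:

```python
from typing import Dict, Iterator, List, Optional

from ingestify import Source
from ingestify.domain import DataSpecVersionCollection, DraftFile, Identifier, Revision


class PagedMatchSource(Source):
    """Hypothetical source that discovers matches ten at a time."""

    @property
    def provider(self) -> str:
        return "example"

    def discover_datasets(
        self,
        dataset_type: str,
        data_spec_versions: DataSpecVersionCollection,
        dataset_collection_metadata,
        competition_id,
        season_id,
        **kwargs,
    ) -> Iterator[List[Dict]]:
        # Yield batches: the loader turns every yielded list into create/update
        # tasks and runs them before this generator is resumed.
        batch: List[Dict] = []
        for match_id in range(25):  # stand-in for a remote match index
            batch.append(
                dict(competition_id=competition_id, season_id=season_id, match_id=match_id)
            )
            if len(batch) == 10:
                yield batch
                batch = []
        if batch:
            yield batch

    def fetch_dataset_files(
        self,
        dataset_type: str,
        identifier: Identifier,
        data_spec_versions: DataSpecVersionCollection,
        current_revision: Optional[Revision],
    ):
        return {"file1": DraftFile.from_input("match data for " + identifier.key)}
```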
diff --git a/ingestify/domain/models/dataset/collection.py b/ingestify/domain/models/dataset/collection.py
index cc17d96..81d8f6b 100644
--- a/ingestify/domain/models/dataset/collection.py
+++ b/ingestify/domain/models/dataset/collection.py
@@ -1,16 +1,26 @@
-from typing import List
+from typing import List, Optional
 
+from .collection_metadata import DatasetCollectionMetadata
 from .dataset import Dataset
 from .identifier import Identifier
 
 
 class DatasetCollection:
-    def __init__(self, datasets: List[Dataset] = None):
+    def __init__(
+        self,
+        metadata: Optional[DatasetCollectionMetadata] = None,
+        datasets: Optional[List[Dataset]] = None,
+    ):
         datasets = datasets or []
+        # TODO: this fails when datasets contains different dataset_types with overlapping identifiers
         self.datasets: dict[str, Dataset] = {
             dataset.identifier.key: dataset for dataset in datasets
         }
+        self.metadata = metadata
+
+    def loaded(self):
+        return self.metadata.count == len(self.datasets)
 
     def get(self, dataset_identifier: Identifier) -> Dataset:
         return self.datasets.get(dataset_identifier.key)
diff --git a/ingestify/domain/models/dataset/collection_metadata.py b/ingestify/domain/models/dataset/collection_metadata.py
new file mode 100644
index 0000000..6a90f98
--- /dev/null
+++ b/ingestify/domain/models/dataset/collection_metadata.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Optional
+
+
+@dataclass
+class DatasetCollectionMetadata:
+    last_modified: Optional[datetime]
+    count: int
diff --git a/ingestify/domain/models/dataset/dataset_repository.py b/ingestify/domain/models/dataset/dataset_repository.py
index 9352d7b..a0a39aa 100644
--- a/ingestify/domain/models/dataset/dataset_repository.py
+++ b/ingestify/domain/models/dataset/dataset_repository.py
@@ -18,7 +18,8 @@ def get_dataset_collection(
         dataset_type: Optional[str] = None,
         dataset_id: Optional[Union[str, List[str]]] = None,
         provider: Optional[str] = None,
-        selector: Optional[Selector] = None,
+        selector: Optional[Union[Selector, List[Selector]]] = None,
+        metadata_only: bool = False,
     ) -> DatasetCollection:
         pass
diff --git a/ingestify/domain/models/dataset/file.py b/ingestify/domain/models/dataset/file.py
index 2b1916c..77136ee 100644
--- a/ingestify/domain/models/dataset/file.py
+++ b/ingestify/domain/models/dataset/file.py
@@ -28,9 +28,9 @@ class DraftFile:
     def from_input(
         cls,
         file_,
+        data_serialization_format="txt",
         data_feed_key=None,
         data_spec_version=None,
-        data_serialization_format=None,
         modified_at=None,
     ):
         # Pass-through for these types
diff --git a/ingestify/domain/models/dataset/file_collection.py b/ingestify/domain/models/dataset/file_collection.py
index 63715f0..19e58e9 100644
--- a/ingestify/domain/models/dataset/file_collection.py
+++ b/ingestify/domain/models/dataset/file_collection.py
@@ -13,7 +13,7 @@ def get_file(
         self,
         data_feed_key: Optional[str] = None,
         data_spec_version: Optional[str] = None,
-        auto_rewind: Optional[bool] = None
+        auto_rewind: Optional[bool] = None,
     ) -> Optional[LoadedFile]:
         if not data_feed_key and not data_spec_version:
             raise ValueError(
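Note on `DatasetCollectionMetadata` introduced above: it is a cheap summary (newest file timestamp plus row count) that the loader fetches with `metadata_only=True` and hands to `discover_datasets`, so a source can decide to do an incremental crawl instead of a full one. A small sketch of that idea; the `updated_since` filter is an assumption about how a concrete source might use `last_modified`, not something this change defines:

```python
from datetime import datetime, timezone
from typing import Dict, Optional

from ingestify.domain.models.dataset.collection_metadata import DatasetCollectionMetadata


def build_discover_filters(metadata: Optional[DatasetCollectionMetadata]) -> Dict:
    # Nothing stored yet (or no files at all): crawl everything.
    if metadata is None or metadata.count == 0 or metadata.last_modified is None:
        return {}
    # Otherwise only ask upstream for datasets changed after the newest file we
    # already hold (hypothetical "updated_since" query parameter).
    return {"updated_since": metadata.last_modified}


# In the loader this object comes from
# store.get_dataset_collection(..., metadata_only=True).metadata
meta = DatasetCollectionMetadata(
    last_modified=datetime(2024, 1, 1, tzinfo=timezone.utc), count=42
)
print(build_discover_filters(meta))
```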
diff --git a/ingestify/domain/models/source.py b/ingestify/domain/models/source.py
index 2fdd715..1780670 100644
--- a/ingestify/domain/models/source.py
+++ b/ingestify/domain/models/source.py
@@ -1,11 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Iterable, Iterator, Union
 
 # from ingestify.utils import ComponentFactory, ComponentRegistry
 from . import DraftFile
 from .data_spec_version_collection import DataSpecVersionCollection
 from .dataset import Identifier, Revision
+from .dataset.collection_metadata import DatasetCollectionMetadata
 
 
 class Source(ABC):
@@ -24,8 +25,12 @@ def provider(self) -> str:
 
     @abstractmethod
     def discover_datasets(
-        self, dataset_type: str, data_spec_versions: DataSpecVersionCollection, **kwargs
-    ) -> List[Dict]:
+        self,
+        dataset_type: str,
+        data_spec_versions: DataSpecVersionCollection,
+        dataset_collection_metadata: DatasetCollectionMetadata,
+        **kwargs
+    ) -> Union[List[Dict], Iterator[List[Dict]]]:
         pass
 
     @abstractmethod
diff --git a/ingestify/infra/store/dataset/local_dataset_repository.py b/ingestify/infra/store/dataset/local_dataset_repository.py
index 04881bf..0bd00b1 100644
--- a/ingestify/infra/store/dataset/local_dataset_repository.py
+++ b/ingestify/infra/store/dataset/local_dataset_repository.py
@@ -33,6 +33,9 @@ def supports(cls, url: str) -> bool:
 
     def __init__(self, url: str):
         self.base_dir = Path(url[7:])
+        raise DeprecationWarning(
+            "This Repository should not be used. Use SqlAlchemyDatasetRepository with a local sqlite database instead."
+        )
 
     def get_dataset_collection(
         self,
diff --git a/ingestify/infra/store/dataset/sqlalchemy/repository.py b/ingestify/infra/store/dataset/sqlalchemy/repository.py
index 0523f24..41f1ec8 100644
--- a/ingestify/infra/store/dataset/sqlalchemy/repository.py
+++ b/ingestify/infra/store/dataset/sqlalchemy/repository.py
@@ -2,11 +2,12 @@
 import uuid
 from typing import Optional, Union, List
 
-from sqlalchemy import create_engine, func, text
+from sqlalchemy import create_engine, func, text, tuple_
 from sqlalchemy.engine import make_url
 from sqlalchemy.exc import NoSuchModuleError
 from sqlalchemy.orm import Session, joinedload
 
+from ingestify.domain import File
 from ingestify.domain.models import (
     Dataset,
     DatasetCollection,
@@ -14,6 +15,9 @@
     Identifier,
     Selector,
 )
+from ingestify.domain.models.dataset.collection_metadata import (
+    DatasetCollectionMetadata,
+)
 
 from .mapping import dataset_table, metadata
 
@@ -105,19 +109,16 @@ def __setstate__(self, state):
         self.url = state["url"]
         self._init_engine()
 
-    def get_dataset_collection(
+    def _filter_query(
         self,
+        query,
         bucket: str,
         dataset_type: Optional[str] = None,
         provider: Optional[str] = None,
         dataset_id: Optional[Union[str, List[str]]] = None,
-        selector: Optional[Selector] = None,
-    ) -> DatasetCollection:
-        query = (
-            self.session.query(Dataset)
-            .options(joinedload(Dataset.revisions))
-            .filter(Dataset.bucket == bucket)
-        )
+        selector: Optional[Union[Selector, List[Selector]]] = None,
+    ):
+        query = query.filter(Dataset.bucket == bucket)
         if dataset_type:
             query = query.filter(Dataset.dataset_type == dataset_type)
         if provider:
@@ -135,27 +136,93 @@
         dialect = self.session.bind.dialect.name
 
-        where, selector = selector.split("where")
+        if not isinstance(selector, list):
+            where, selector = selector.split("where")
+        else:
+            where = None
+
         if selector:
-            for k, v in selector.filtered_attributes.items():
+            if isinstance(selector, list):
+                selectors = selector
+            else:
+                selectors = [selector]
+
+            if not selectors:
+                raise ValueError("Selectors must contain at least one item")
+
+            keys = list(selectors[0].filtered_attributes.keys())
+
+            columns = []
+            first_selector = selectors[0].filtered_attributes
+
+            # Create a query like this:
+            # SELECT * FROM dataset WHERE (column1, column2, column3) IN ((1, 2, 3), (4, 5, 6), (7, 8, 9))
+            for k in keys:
                 if dialect == "postgresql":
                     column = dataset_table.c.identifier[k]
+
+                    # Take the value from the first selector to determine the type.
+                    # TODO: check all selectors to determine the type
+                    v = first_selector[k]
                     if isint(v):
                         column = column.as_integer()
                     elif isfloat(v):
                         column = column.as_float()
                     else:
                         column = column.as_string()
-                    query = query.filter(column == v)
                 else:
-                    query = query.filter(
-                        func.json_extract(Dataset.identifier, f"$.{k}") == v
-                    )
+                    column = func.json_extract(Dataset.identifier, f"$.{k}")
+                columns.append(column)
+
+            values = []
+            for selector in selectors:
+                filtered_attributes = selector.filtered_attributes
+                values.append(tuple([filtered_attributes[k] for k in keys]))
+
+            query = query.filter(tuple_(*columns).in_(values))
 
         if where:
             query = query.filter(text(where))
 
+        return query
+
+    def get_dataset_collection(
+        self,
+        bucket: str,
+        dataset_type: Optional[str] = None,
+        provider: Optional[str] = None,
+        dataset_id: Optional[Union[str, List[str]]] = None,
+        selector: Optional[Union[Selector, List[Selector]]] = None,
+        metadata_only: bool = False,
+    ) -> DatasetCollection:
+
+        def apply_query_filter(query):
+            return self._filter_query(
+                query,
+                bucket=bucket,
+                dataset_type=dataset_type,
+                provider=provider,
+                dataset_id=dataset_id,
+                selector=selector,
+            )
+
+        if not metadata_only:
+            dataset_query = apply_query_filter(
+                self.session.query(Dataset).options(joinedload(Dataset.revisions))
+            )
+            datasets = list(dataset_query)
+        else:
+            datasets = []
+
+        metadata_result = list(
+            apply_query_filter(
+                self.session.query(func.max(File.modified_at), func.count())
+            )
+        )[0]
+        dataset_collection_metadata = DatasetCollectionMetadata(
+            last_modified=metadata_result[0], count=metadata_result[1]
+        )
 
-        return DatasetCollection(list(query))
+        return DatasetCollection(dataset_collection_metadata, datasets)
 
     def save(self, bucket: str, dataset: Dataset):
         # Just make sure
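Note on the `_filter_query` rewrite above: a list of selectors is translated into a single composite `IN` on the identifier columns instead of one query per selector. The construct can be tried in isolation with a throwaway table; this snippet only demonstrates `tuple_(...).in_(...)` (SQLAlchemy 1.4+ style) and is not part of the patch. Row-value `IN` needs backend support (PostgreSQL, MySQL, SQLite 3.15+):

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select, tuple_

metadata = MetaData()
dataset = Table(
    "dataset",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("competition_id", Integer),
    Column("season_id", Integer),
)

# One query for many selectors: (competition_id, season_id) IN ((1, 2), (1, 3))
stmt = select(dataset).where(
    tuple_(dataset.c.competition_id, dataset.c.season_id).in_([(1, 2), (1, 3)])
)

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(
        dataset.insert(),
        [
            {"id": 1, "competition_id": 1, "season_id": 2},
            {"id": 2, "competition_id": 1, "season_id": 4},
        ],
    )
    print(conn.execute(stmt).fetchall())  # only the (1, 2) row matches
```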
diff --git a/ingestify/main.py b/ingestify/main.py
index b1d6b80..4ab378a 100644
--- a/ingestify/main.py
+++ b/ingestify/main.py
@@ -195,7 +195,7 @@ def get_engine(config_file, bucket: Optional[str] = None) -> IngestionEngine:
 
         import_job = ExtractJob(
             source=sources[job["source"]],
-            dataset_type=job.get("dataset_type"),
+            dataset_type=job["dataset_type"],
             selectors=selectors,
             fetch_policy=fetch_policy,
             data_spec_versions=data_spec_versions,
diff --git a/ingestify/tests/config.yaml b/ingestify/tests/config.yaml
new file mode 100644
index 0000000..dbaf855
--- /dev/null
+++ b/ingestify/tests/config.yaml
@@ -0,0 +1,8 @@
+main:
+  # Cannot use an in-memory database because the database is shared between processes
+  dataset_url: !ENV "sqlite:///${TEST_DIR}/main.db"
+  file_url: !ENV file://${TEST_DIR}/data
+  default_bucket: main
+
+sources: {}
+extract_jobs: []
diff --git a/ingestify/tests/conftest.py b/ingestify/tests/conftest.py
new file mode 100644
index 0000000..2cee097
--- /dev/null
+++ b/ingestify/tests/conftest.py
@@ -0,0 +1,12 @@
+import tempfile
+
+import pytest
+import os
+
+
+@pytest.fixture(scope="function", autouse=True)
+def datastore_dir():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        os.environ["TEST_DIR"] = tmpdirname
+        os.environ["INGESTIFY_RUN_EAGER"] = "true"
+        yield tmpdirname
diff --git a/ingestify/tests/test_engine.py b/ingestify/tests/test_engine.py
new file mode 100644
index 0000000..b7cf7bc
--- /dev/null
+++ b/ingestify/tests/test_engine.py
@@ -0,0 +1,214 @@
+from datetime import datetime
+from typing import Optional
+
+import pytz
+
+from ingestify import Source
+from ingestify.application.ingestion_engine import IngestionEngine
+from ingestify.domain import (
+    Identifier,
+    Selector,
+    DataSpecVersionCollection,
+    DraftFile,
+    Revision,
+)
+from ingestify.domain.models.extract_job import ExtractJob
+from ingestify.domain.models.fetch_policy import FetchPolicy
+from ingestify.main import get_engine
+
+
+def add_extract_job(engine: IngestionEngine, source: Source, **selector):
+    data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
+
+    engine.add_extract_job(
+        ExtractJob(
+            source=source,
+            fetch_policy=FetchPolicy(),
+            selectors=[Selector.build(selector, data_spec_versions=data_spec_versions)],
+            dataset_type="match",
+            data_spec_versions=data_spec_versions,
+        )
+    )
+
+
+class SimpleFakeSource(Source):
+    @property
+    def provider(self) -> str:
+        return "fake"
+
+    def discover_datasets(
+        self,
+        dataset_type: str,
+        data_spec_versions: DataSpecVersionCollection,
+        competition_id,
+        season_id,
+        **kwargs
+    ):
+        return [
+            dict(
+                competition_id=competition_id,
+                season_id=season_id,
+                _name="Test Dataset",
+                _last_modified=datetime.now(pytz.utc),
+            )
+        ]
+
+    def fetch_dataset_files(
+        self,
+        dataset_type: str,
+        identifier: Identifier,
+        data_spec_versions: DataSpecVersionCollection,
+        current_revision: Optional[Revision],
+    ):
+        if current_revision:
+            return {
+                "file1": DraftFile.from_input(
+                    "different_content",
+                ),
+                "file2": DraftFile.from_input("some_content" + identifier.key),
+            }
+        else:
+            return {
+                "file1": DraftFile.from_input(
+                    "content1",
+                ),
+                "file2": DraftFile.from_input("some_content" + identifier.key),
+            }
+
+
+class BatchSource(Source):
+    def __init__(self, name, callback):
+        super().__init__(name)
+        self.callback = callback
+        self.should_stop = False
+        self.idx = 0
+
+    @property
+    def provider(self) -> str:
+        return "fake"
+
+    def discover_datasets(
+        self,
+        dataset_type: str,
+        data_spec_versions: DataSpecVersionCollection,
+        competition_id,
+        season_id,
+        **kwargs
+    ):
+        while not self.should_stop:
+            items = []
+            for i in range(10):
+                match_id = self.idx
+                self.idx += 1
+                item = dict(
+                    competition_id=competition_id,
+                    season_id=season_id,
+                    match_id=match_id,
+                    _name="Test Dataset",
+                    _last_modified=datetime.now(pytz.utc),
+                )
+                items.append(item)
+            yield items
+            self.callback and self.callback(self.idx)
+
+    def fetch_dataset_files(
+        self,
+        dataset_type: str,
+        identifier: Identifier,
+        data_spec_versions: DataSpecVersionCollection,
+        current_revision: Optional[Revision],
+    ):
+        if current_revision:
+            return {
+                "file1": DraftFile.from_input(
+                    "different_content",
+                ),
+                "file2": DraftFile.from_input("some_content" + identifier.key),
+            }
+        else:
+            return {
+                "file1": DraftFile.from_input(
+                    "content1",
+                ),
+                "file2": DraftFile.from_input("some_content" + identifier.key),
+            }
+
+
+def test_engine():
+    engine = get_engine("config.yaml", "main")
+
+    add_extract_job(
+        engine, SimpleFakeSource("fake-source"), competition_id=1, season_id=2
+    )
+    engine.load()
+    datasets = engine.store.get_dataset_collection()
+    assert len(datasets) == 1
+
+    dataset = datasets.first()
+    assert dataset.identifier == Identifier(competition_id=1, season_id=2)
+    assert len(dataset.revisions) == 1
+
+    engine.load()
+    datasets = engine.store.get_dataset_collection()
+    assert len(datasets) == 1
+
+    dataset = datasets.first()
+    assert dataset.identifier == Identifier(competition_id=1, season_id=2)
+    assert len(dataset.revisions) == 2
+    assert len(dataset.revisions[0].modified_files) == 2
+    assert len(dataset.revisions[1].modified_files) == 1
+
+    add_extract_job(
+        engine, SimpleFakeSource("fake-source"), competition_id=1, season_id=3
+    )
+    engine.load()
+
+    datasets = engine.store.get_dataset_collection()
+    assert len(datasets) == 2
+
+    datasets = engine.store.get_dataset_collection(season_id=3)
+    assert len(datasets) == 1
+
+
+def test_iterator_source():
+    """Test when a Source returns an Iterator to do batch processing.
+
+    Every batch must be executed right away.
+    """
+    engine = get_engine("config.yaml", "main")
+
+    batch_source = None
+
+    def callback(idx):
+        nonlocal batch_source
+        datasets = engine.store.get_dataset_collection()
+        assert len(datasets) == idx
+
+        if idx == 100:
+            batch_source.should_stop = True
+
+    batch_source = BatchSource("fake-source", callback)
+
+    add_extract_job(engine, batch_source, competition_id=1, season_id=2)
+    engine.load()
+
+    datasets = engine.store.get_dataset_collection()
+    assert len(datasets) == 100
+    for dataset in datasets:
+        assert len(dataset.revisions) == 1
+
+    # Now let's run again. This should create new revisions
+    batch_source.idx = 0
+    batch_source.should_stop = False
+
+    def callback(idx):
+        if idx == 100:
+            batch_source.should_stop = True
+
+    batch_source.callback = callback
+
+    engine.load()
+    datasets = engine.store.get_dataset_collection()
+    assert len(datasets) == 100
+    for dataset in datasets:
+        assert len(dataset.revisions) == 2
diff --git a/ingestify/utils.py b/ingestify/utils.py
index 30d4551..a0bff67 100644
--- a/ingestify/utils.py
+++ b/ingestify/utils.py
@@ -145,6 +145,11 @@ def matches(self, attributes: Dict) -> bool:
     def filtered_attributes(self):
         return {k: v for k, v in self.attributes.items() if not k.startswith("_")}
 
+    def __eq__(self, other):
+        if isinstance(other, AttributeBag):
+            return self.key == other.key
+        return NotImplemented
+
     def __hash__(self):
         return hash(self.key)
 
@@ -193,3 +197,41 @@ def map_in_pool(func, iterable, processes=0):
         return pool.map(
             cloud_unpack_and_call, ((wrapped_fn, item) for item in iterable)
         )
+
+
+class SyncPool:
+    def map_async(self, func, iterable):
+        return [func(item) for item in iterable]
+
+    def close(self):
+        pass
+
+    def join(self):
+        return True
+
+
+class TaskExecutor:
+    def __init__(self, processes=0):
+        if os.environ.get("INGESTIFY_RUN_EAGER") == "true":
+            pool = SyncPool()
+        else:
+            if not processes:
+                processes = int(os.environ.get("INGESTIFY_CONCURRENCY", "0"))
+
+            if "fork" in get_all_start_methods():
+                ctx = get_context("fork")
+            else:
+                ctx = get_context("spawn")
+
+            pool = ctx.Pool(processes or cpu_count())
+        self.pool = pool
+
+    def run(self, func, iterable):
+        wrapped_fn = cloudpickle.dumps(func)
+        self.pool.map_async(
+            cloud_unpack_and_call, ((wrapped_fn, item) for item in iterable)
+        )
+
+    def join(self):
+        self.pool.close()
+        self.pool.join()
diff --git a/setup.py b/setup.py
index cae8de7..e8fb567 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
 def package_files(directory):
     paths = []
-    for (path, directories, filenames) in os.walk(directory):
+    for path, directories, filenames in os.walk(directory):
         for filename in filenames:
             paths.append(os.path.join("..", path, filename))
     return paths
@@ -37,8 +37,9 @@ def setup_package():
         "jinja2",
         "python-dotenv",
         "pyaml_env",
-        "boto3"
+        "boto3",
     ],
+    extras_require={"test": ["pytest>=6.2.5,<7"]},
     )
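A closing note on the execution model that ties conftest.py and utils.py together: with `INGESTIFY_RUN_EAGER=true` the new `TaskExecutor` falls back to the in-process `SyncPool`, which is what lets the batch test assert on store contents from inside the source's callback; without it, tasks go to a fork/spawn multiprocessing pool sized by `INGESTIFY_CONCURRENCY` or the CPU count. A minimal sketch (the `show` task is obviously just an illustration):

```python
import os

from ingestify.utils import TaskExecutor

# Same switch the test fixture sets; must be set before the executor is built.
os.environ["INGESTIFY_RUN_EAGER"] = "true"


def show(item):
    # With SyncPool this runs synchronously, in submission order.
    print(f"processing {item}")


executor = TaskExecutor()  # picks SyncPool because of the env var
executor.run(show, [1, 2, 3])
executor.join()
```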