From a258911b37de838032953898588e4a98a7fc8b54 Mon Sep 17 00:00:00 2001
From: Denis Rykov
Date: Tue, 17 Oct 2023 21:47:12 +0200
Subject: [PATCH] Add partition_update_enabled option

Thread a new `partition_update_enabled` flag (default True, so existing
callers behave as before) from `load` in pypgstac.py through
`Loader.load_items`, `read_hydrated` / `read_dehydrated`,
`_partition_update` and `load_partition`.

When the flag is False:
  * `_partition_update` no longer widens the cached partition bounds or
    sets `requires_update`, and raises instead of creating a cache entry
    for a partition that does not exist;
  * `load_partition` skips the `update_partition_stats_q` call.

---
Reviewer note: the "partition does not exist" error is raised as
ValueError (not a bare Exception) so that it matches the
`pytest.raises(ValueError)` assertion in the test added below. The blob
id on the load.py `index` line was not regenerated after this change.

 src/pypgstac/python/pypgstac/load.py     | 67 ++++++++++++++++--------
 src/pypgstac/python/pypgstac/pypgstac.py |  5 +-
 src/pypgstac/tests/test_load.py          | 17 ++++++
 3 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/src/pypgstac/python/pypgstac/load.py b/src/pypgstac/python/pypgstac/load.py
index 1ce62096..0efa8efa 100644
--- a/src/pypgstac/python/pypgstac/load.py
+++ b/src/pypgstac/python/pypgstac/load.py
@@ -270,6 +270,7 @@ def load_partition(
         partition: Partition,
         items: Iterable[Dict[str, Any]],
         insert_mode: Optional[Methods] = Methods.insert,
+        partition_update_enabled: Optional[bool] = True,
     ) -> None:
         """Load items data for a single partition."""
         conn = self.db.connect()
@@ -441,12 +442,17 @@ def load_partition(
                         "Available modes are insert, ignore, upsert, and delsert."
                         f"You entered {insert_mode}.",
                     )
-                cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
+                if partition_update_enabled:
+                    cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
         logger.debug(
             f"Copying data for {partition} took {time.perf_counter() - t} seconds",
         )
 
-    def _partition_update(self, item: Dict[str, Any]) -> str:
+    def _partition_update(
+        self,
+        item: Dict[str, Any],
+        update_enabled: Optional[bool] = True,
+    ) -> str:
         """Update the cached partition with the item information and return the name.
 
         This method will mark the partition as dirty if the bounds of the partition
@@ -512,20 +518,24 @@ def _partition_update(self, item: Dict[str, Any]) -> str:
             partition = self._partition_cache[partition_name]
 
         if partition:
-            # Only update the partition if the item is outside the current bounds
-            if item["datetime"] < partition.datetime_range_min:
-                partition.datetime_range_min = item["datetime"]
-                partition.requires_update = True
-            if item["datetime"] > partition.datetime_range_max:
-                partition.datetime_range_max = item["datetime"]
-                partition.requires_update = True
-            if item["end_datetime"] < partition.end_datetime_range_min:
-                partition.end_datetime_range_min = item["end_datetime"]
-                partition.requires_update = True
-            if item["end_datetime"] > partition.end_datetime_range_max:
-                partition.end_datetime_range_max = item["end_datetime"]
-                partition.requires_update = True
+            if update_enabled:
+                # Only update the partition if the item is outside the current bounds
+                if item["datetime"] < partition.datetime_range_min:
+                    partition.datetime_range_min = item["datetime"]
+                    partition.requires_update = True
+                if item["datetime"] > partition.datetime_range_max:
+                    partition.datetime_range_max = item["datetime"]
+                    partition.requires_update = True
+                if item["end_datetime"] < partition.end_datetime_range_min:
+                    partition.end_datetime_range_min = item["end_datetime"]
+                    partition.requires_update = True
+                if item["end_datetime"] > partition.end_datetime_range_max:
+                    partition.end_datetime_range_max = item["end_datetime"]
+                    partition.requires_update = True
         else:
+            if not update_enabled:
+                raise ValueError(f"Partition {partition_name} does not exist.")
+
             # No partition exists yet; create a new one from item
             partition = Partition(
                 name=partition_name,
@@ -541,7 +551,11 @@ def _partition_update(self, item: Dict[str, Any]) -> str:
 
         return partition_name
 
-    def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
+    def read_dehydrated(
+        self,
+        file: Union[Path, str] = "stdin",
+        partition_update_enabled: Optional[bool] = True,
+    ) -> Generator:
         if file is None:
             file = "stdin"
         if isinstance(file, str):
@@ -572,15 +586,21 @@ def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
                             item[field] = content_value
                         else:
                             item[field] = tab_split[i]
-                    item["partition"] = self._partition_update(item)
+                    item["partition"] = self._partition_update(
+                        item,
+                        partition_update_enabled,
+                    )
                     yield item
 
     def read_hydrated(
-        self, file: Union[Path, str, Iterator[Any]] = "stdin",
+        self,
+        file: Union[Path, str,
+        Iterator[Any]] = "stdin",
+        partition_update_enabled: Optional[bool] = True,
     ) -> Generator:
         for line in read_json(file):
             item = self.format_item(line)
-            item["partition"] = self._partition_update(item)
+            item["partition"] = self._partition_update(item, partition_update_enabled)
             yield item
 
     def load_items(
@@ -589,6 +609,7 @@ def load_items(
         insert_mode: Optional[Methods] = Methods.insert,
         dehydrated: Optional[bool] = False,
         chunksize: Optional[int] = 10000,
+        partition_update_enabled: Optional[bool] = True,
     ) -> None:
         """Load items json records."""
         self.check_version()
@@ -599,15 +620,17 @@ def load_items(
         self._partition_cache = {}
 
         if dehydrated and isinstance(file, str):
-            items = self.read_dehydrated(file)
+            items = self.read_dehydrated(file, partition_update_enabled)
         else:
-            items = self.read_hydrated(file)
+            items = self.read_hydrated(file, partition_update_enabled)
 
         for chunkin in chunked_iterable(items, chunksize):
             chunk = list(chunkin)
             chunk.sort(key=lambda x: x["partition"])
             for k, g in itertools.groupby(chunk, lambda x: x["partition"]):
-                self.load_partition(self._partition_cache[k], g, insert_mode)
+                self.load_partition(
+                    self._partition_cache[k], g, insert_mode, partition_update_enabled,
+                )
 
         logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.")
 
diff --git a/src/pypgstac/python/pypgstac/pypgstac.py b/src/pypgstac/python/pypgstac/pypgstac.py
index b94c3ec9..d2a860d7 100644
--- a/src/pypgstac/python/pypgstac/pypgstac.py
+++ b/src/pypgstac/python/pypgstac/pypgstac.py
@@ -63,13 +63,16 @@ def load(
         method: Optional[Methods] = Methods.insert,
         dehydrated: Optional[bool] = False,
         chunksize: Optional[int] = 10000,
+        partition_update_enabled: Optional[bool] = True,
     ) -> None:
         """Load collections or items into PGStac."""
         loader = Loader(db=self._db)
         if table == "collections":
             loader.load_collections(file, method)
         if table == "items":
-            loader.load_items(file, method, dehydrated, chunksize)
+            loader.load_items(
+                file, method, dehydrated, chunksize, partition_update_enabled,
+            )
 
     def loadextensions(self) -> None:
         conn = self._db.connect()
diff --git a/src/pypgstac/tests/test_load.py b/src/pypgstac/tests/test_load.py
index 49c5d0ef..2f501318 100644
--- a/src/pypgstac/tests/test_load.py
+++ b/src/pypgstac/tests/test_load.py
@@ -441,3 +441,20 @@ def test_load_items_nopartitionconstraint_succeeds(loader: Loader) -> None:
         str(TEST_ITEMS),
         insert_mode=Methods.upsert,
     )
+
+
+def test_load_items_when_partition_creation_disabled(loader: Loader) -> None:
+    """
+    Test pypgstac items loader raises an exception when partition
+    does not exist and partition creation is disabled.
+    """
+    loader.load_collections(
+        str(TEST_COLLECTIONS_JSON),
+        insert_mode=Methods.insert,
+    )
+    with pytest.raises(ValueError):
+        loader.load_items(
+            str(TEST_ITEMS),
+            insert_mode=Methods.insert,
+            partition_update_enabled=False,
+        )