Skip to content

Commit

Permalink
Add partition_update_enabled option
Browse files Browse the repository at this point in the history
  • Loading branch information
drnextgis committed Oct 20, 2023
1 parent 1ea6c5d commit a258911
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 23 deletions.
67 changes: 45 additions & 22 deletions src/pypgstac/python/pypgstac/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def load_partition(
partition: Partition,
items: Iterable[Dict[str, Any]],
insert_mode: Optional[Methods] = Methods.insert,
partition_update_enabled: Optional[bool] = True,
) -> None:
"""Load items data for a single partition."""
conn = self.db.connect()
Expand Down Expand Up @@ -441,12 +442,17 @@ def load_partition(
"Available modes are insert, ignore, upsert, and delsert."
f"You entered {insert_mode}.",
)
cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
if partition_update_enabled:
cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
logger.debug(
f"Copying data for {partition} took {time.perf_counter() - t} seconds",
)

def _partition_update(self, item: Dict[str, Any]) -> str:
def _partition_update(
self,
item: Dict[str, Any],
update_enabled: Optional[bool] = True,
) -> str:
"""Update the cached partition with the item information and return the name.
This method will mark the partition as dirty if the bounds of the partition
Expand Down Expand Up @@ -512,20 +518,24 @@ def _partition_update(self, item: Dict[str, Any]) -> str:
partition = self._partition_cache[partition_name]

if partition:
# Only update the partition if the item is outside the current bounds
if item["datetime"] < partition.datetime_range_min:
partition.datetime_range_min = item["datetime"]
partition.requires_update = True
if item["datetime"] > partition.datetime_range_max:
partition.datetime_range_max = item["datetime"]
partition.requires_update = True
if item["end_datetime"] < partition.end_datetime_range_min:
partition.end_datetime_range_min = item["end_datetime"]
partition.requires_update = True
if item["end_datetime"] > partition.end_datetime_range_max:
partition.end_datetime_range_max = item["end_datetime"]
partition.requires_update = True
if update_enabled:
# Only update the partition if the item is outside the current bounds
if item["datetime"] < partition.datetime_range_min:
partition.datetime_range_min = item["datetime"]
partition.requires_update = True
if item["datetime"] > partition.datetime_range_max:
partition.datetime_range_max = item["datetime"]
partition.requires_update = True
if item["end_datetime"] < partition.end_datetime_range_min:
partition.end_datetime_range_min = item["end_datetime"]
partition.requires_update = True
if item["end_datetime"] > partition.end_datetime_range_max:
partition.end_datetime_range_max = item["end_datetime"]
partition.requires_update = True
else:
if not update_enabled:
raise Exception(f"Partition {partition_name} does not exist.")

# No partition exists yet; create a new one from item
partition = Partition(
name=partition_name,
Expand All @@ -541,7 +551,11 @@ def _partition_update(self, item: Dict[str, Any]) -> str:

return partition_name

def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
def read_dehydrated(
self,
file: Union[Path, str] = "stdin",
partition_update_enabled: Optional[bool] = True,
) -> Generator:
if file is None:
file = "stdin"
if isinstance(file, str):
Expand Down Expand Up @@ -572,15 +586,21 @@ def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
item[field] = content_value
else:
item[field] = tab_split[i]
item["partition"] = self._partition_update(item)
item["partition"] = self._partition_update(
item,
partition_update_enabled,
)
yield item

def read_hydrated(
self, file: Union[Path, str, Iterator[Any]] = "stdin",
self,
file: Union[Path, str,
Iterator[Any]] = "stdin",
partition_update_enabled: Optional[bool] = True,
) -> Generator:
for line in read_json(file):
item = self.format_item(line)
item["partition"] = self._partition_update(item)
item["partition"] = self._partition_update(item, partition_update_enabled)
yield item

def load_items(
Expand All @@ -589,6 +609,7 @@ def load_items(
insert_mode: Optional[Methods] = Methods.insert,
dehydrated: Optional[bool] = False,
chunksize: Optional[int] = 10000,
partition_update_enabled: Optional[bool] = True,
) -> None:
"""Load items json records."""
self.check_version()
Expand All @@ -599,15 +620,17 @@ def load_items(
self._partition_cache = {}

if dehydrated and isinstance(file, str):
items = self.read_dehydrated(file)
items = self.read_dehydrated(file, partition_update_enabled)
else:
items = self.read_hydrated(file)
items = self.read_hydrated(file, partition_update_enabled)

for chunkin in chunked_iterable(items, chunksize):
chunk = list(chunkin)
chunk.sort(key=lambda x: x["partition"])
for k, g in itertools.groupby(chunk, lambda x: x["partition"]):
self.load_partition(self._partition_cache[k], g, insert_mode)
self.load_partition(
self._partition_cache[k], g, insert_mode, partition_update_enabled,
)

logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.")

Expand Down
5 changes: 4 additions & 1 deletion src/pypgstac/python/pypgstac/pypgstac.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,16 @@ def load(
method: Optional[Methods] = Methods.insert,
dehydrated: Optional[bool] = False,
chunksize: Optional[int] = 10000,
partition_update_enabled: Optional[bool] = True,
) -> None:
"""Load collections or items into PGStac."""
loader = Loader(db=self._db)
if table == "collections":
loader.load_collections(file, method)
if table == "items":
loader.load_items(file, method, dehydrated, chunksize)
loader.load_items(
file, method, dehydrated, chunksize, partition_update_enabled,
)

def loadextensions(self) -> None:
conn = self._db.connect()
Expand Down
17 changes: 17 additions & 0 deletions src/pypgstac/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,20 @@ def test_load_items_nopartitionconstraint_succeeds(loader: Loader) -> None:
str(TEST_ITEMS),
insert_mode=Methods.upsert,
)


def test_load_items_when_partition_creation_disabled(loader: Loader) -> None:
    """
    Check that the items loader raises when the target partition does not
    exist and ``partition_update_enabled`` is switched off.
    """
    # Load the collections first, then attempt the item load.
    loader.load_collections(str(TEST_COLLECTIONS_JSON), insert_mode=Methods.insert)

    # Same keyword arguments as a direct call, gathered in one place.
    load_options = {
        "insert_mode": Methods.insert,
        "partition_update_enabled": False,
    }

    # NOTE(review): the missing-partition guard in Loader._partition_update
    # appears to raise a plain Exception, which pytest.raises(ValueError)
    # would not catch -- confirm which exception type is intended.
    with pytest.raises(ValueError):
        loader.load_items(str(TEST_ITEMS), **load_options)

0 comments on commit a258911

Please sign in to comment.