Commit 035554d

Optimize: speed up conversion from products to AssetMetadata in HRLVPPMetadataCollector

Keep less data in memory and convert chunk by chunk, so we can immediately save AssetMetadata to GeoParquet along with the products.
JohanKJSchreurs committed Mar 2, 2024
1 parent d024088 commit 035554d
Showing 3 changed files with 64 additions and 56 deletions.
1 change: 0 additions & 1 deletion stacbuilder/__main__.py
@@ -451,7 +451,6 @@ def vpp_upload(collection_path: str, max_items: int):
 def vpp_upload_items(collection_path: str, max_items: int):
     """Upload a collection to the STAC API."""
     settings = get_stac_api_settings()
-    breakpoint()
     commandapi.upload_items_to_stac_api(Path(collection_path), settings=settings, max_items=max_items)


6 changes: 0 additions & 6 deletions stacbuilder/stacapi/upload.py
@@ -56,7 +56,6 @@ def upload_item(self, item) -> dict:
         return self._items_endpoint.create_or_update(item)
 
     def upload_items_bulk(self, collection_id: str, items: Iterable[Item]) -> None:
-        breakpoint()
         for item in items:
             self._prepare_item(item, collection_id)
             item.validate()
@@ -65,17 +64,14 @@ def upload_items_bulk(self, collection_id: str, items: Iterable[Item]) -> None:
     def upload_collection_and_items(
         self, collection: Path | Collection, items: Path | list[Item], max_items: int = -1
     ) -> None:
-        breakpoint()
         collection_out = self.upload_collection(collection)
 
-        breakpoint()
         self.upload_items(collection_out, items, max_items)
 
     def upload_items(self, collection: Path | Collection, items: Path | list[Item], max_items: int = -1) -> None:
         if isinstance(collection, Path):
             collection = Collection.from_file(collection)
 
-        breakpoint()
         items_out: list[Item] = items or []
         if not items:
             _logger.info(f"Using STAC items linked to the collection: {collection.id=}")
@@ -87,15 +83,13 @@ def upload_items(self, collection: Path | Collection, items: Path | list[Item],
             _logger.info(f"Number of STAC item files found: {len(item_paths)}")
             items_out = (Item.from_file(path) for path in item_paths)
 
-        breakpoint()
         if max_items >= 0:
             _logger.info(f"User requested to limit the number of items to {max_items=}")
             items_out = islice(items_out, max_items)
 
         self.upload_items_bulk(collection.id, items_out)
 
     def _prepare_item(self, item: Item, collection_id: str):
-        breakpoint()
         item.collection_id = collection_id
 
         if not item.get_links(pystac.RelType.COLLECTION):
113 changes: 64 additions & 49 deletions stacbuilder/terracatalog.py
@@ -430,14 +430,13 @@ def get_products_as_dataframe(self) -> gpd.GeoDataFrame:
         num_products_processed = 0
         num_prods = catalogue.get_product_count(collection.id)
         max_prods_to_process = self._max_products if self._max_products > 0 else num_prods
-        all_products = []
 
         _logger.info(f"product count for coll_id {collection.id}: {num_prods}")
 
         # If the time slots are too long you will get a terracatalogueclient.exceptions.TooManyResultsException.
         for (slot_start, slot_end), prod_type in self._get_product_query_slots(frequency=self._query_by_frequency):
-            current_products = {}
-            products_as_dicts = []
+            current_products_ids = set()
+            new_products = []
 
             num_prods_in_slot = catalogue.get_product_count(
                 collection.id, start=slot_start, end=slot_end, productType=prod_type
@@ -448,9 +447,12 @@ def get_products_as_dataframe(self) -> gpd.GeoDataFrame:

             limit = None
             if self._max_products:
-                prods_left = max_prods_to_process - len(all_products)
-                if num_prods_in_slot > prods_left:
-                    limit = num_prods_in_slot > prods_left
+                prods_left = max_prods_to_process - num_products_processed
+                if prods_left == 0:
+                    break
+                elif num_prods_in_slot > prods_left:
+                    limit = prods_left
 
             products = list(
                 catalogue.get_products(
                     collection.id, start=slot_start, end=slot_end, productType=prod_type, limit=limit
@@ -462,15 +464,18 @@ def get_products_as_dataframe(self) -> gpd.GeoDataFrame:
                 continue
 
             for product in products:
-                # We should never find duplicates within the same time slot tjat we query.
+                # We should never find duplicates within the same time slot that we query.
                 # If there are duplicates, it should only happen because the time period we ask for is shorter than the
                 # period that the product is for, i.e. the product overlaps several time slots we retrieve.
-                assert product.id not in current_products
-                current_products[product.id] = product
+                if product.id in current_products_ids:
+                    raise DataValidationError(
+                        "Received a duplicate product within the same period. This should never happen."
+                    )
+                current_products_ids.add(product.id)
 
-                # We may already have the product because we have to query the products in small time slots for
-                # example a day or a month.
-                # The time slice that the product is for may be larger than a day, even as long as a year.
+                # We may already have the product in the total set so far, because we have to query the
+                # products in small time slots, for example a day or a month.
+                # The time slice that the product applies to may be larger than a day, even as long as a year.
                 if self._df_products is not None:
                     where_same_prod_and_period = np.logical_and(
                         self._df_products["id"] == product.id,
@@ -486,65 +491,42 @@ def get_products_as_dataframe(self) -> gpd.GeoDataFrame:
                     )
                     continue
 
-                all_products.append(product)
-                products_as_dicts.append(self._product_to_dict(product))
+                new_products.append(product)
 
-            self._log_progress_message(f"Number of new products in current slot {len(products_as_dicts)}.")
-            if not products_as_dicts:
+            self._log_progress_message(f"Number of new products in current slot {len(new_products)}.")
+            if not new_products:
                 # Avoid doing unnecessary work, might add empty dataframes to the total dataframe.
                 continue
 
-            product_records = [{k: v for k, v in pr.items() if k != "geometry"} for pr in products_as_dicts]
-            product_geoms = [pr["geometry"] for pr in products_as_dicts]
-            gdf_products = gpd.GeoDataFrame(data=product_records, crs=EPSG_4326_LATLON, geometry=product_geoms)
-
-            gdf_products.index = gdf_products["id"]
-            gdf_products.sort_index()
-            if self._df_products is None:
-                self._df_products = gdf_products
-            else:
-                self._df_products = pd.concat([self._df_products, gdf_products])
+            self._save_intermediate_geodata(new_products)
 
             num_products_processed = len(self._df_products)
-            self._save_dataframes()
 
             percent_processed = num_products_processed / max_prods_to_process
             self._log_progress_message(
                 f"Retrieved {num_products_processed:_} of {max_prods_to_process:_} products ({percent_processed:.1%})"
             )
             if num_products_processed > max_prods_to_process:
                 break
 
-        self._log_progress_message("START: creating AssetMetadata GeoDataFrame ...")
-        assets_md = [self.create_asset_metadata(p) for p in all_products]
-        asset_records = [{k: v for k, v in md.to_dict().items() if k != "geometry_lat_lon"} for md in assets_md]
-        asset_geoms = [md.geometry_lat_lon for md in assets_md]
-        gdf_asset_md = gpd.GeoDataFrame(data=asset_records, crs=EPSG_4326_LATLON, geometry=asset_geoms)
-        gdf_asset_md.index = gdf_asset_md["asset_id"]
-        gdf_asset_md.sort_index()
-        self._log_progress_message("DONE: creating AssetMetadata GeoDataFrame")
-
-        self._df_asset_metadata = gdf_asset_md
-        self._save_dataframes()
-
         # Verify we have no duplicate products,
         # i.e. the number of unique product IDs must be == to the number of products.
         self._log_progress_message("START sanity checks: no duplicate products present and received all products ...")
-        product_ids = set(p.id for p in all_products)
+        product_ids = set(self._df_products.index)
 
-        if len(product_ids) != len(all_products):
+        if len(product_ids) != len(self._df_products):
             raise DataValidationError(
                 "Sanity check failed in get_products_as_dataframe:"
                 + " The result should not contain duplicate products."
-                + " len(product_ids) != len(all_products)"
-                + f"{len(product_ids)=} {len(all_products)=}"
+                + " len(product_ids) != len(self._df_products)"
+                + f" {len(product_ids)=} {len(self._df_products)=}"
             )
 
-        if len(product_ids) != len(assets_md):
+        if len(product_ids) != len(self._df_asset_metadata.index):
             raise DataValidationError(
                 "Sanity check failed in get_products_as_dataframe:"
                 + " Each products should correspond to exactly 1 AssetMetadata instance."
-                + " len(product_ids) != len(assets_md)"
-                + f"{len(product_ids)=} {len(assets_md)=}"
+                + " len(product_ids) != len(self._df_asset_metadata.index)"
+                + f" {len(product_ids)=} {len(self._df_asset_metadata.index)=}"
             )
 
         # Check that we have processed all products, based on the product count reported by the terracatalogueclient.
@@ -553,14 +535,47 @@ def get_products_as_dataframe(self) -> gpd.GeoDataFrame:
             raise DataValidationError(
                 "Sanity check failed in get_products_as_dataframe:"
                 + "Number of products in result must be the product count reported by terracataloguiclient"
-                + "len(product_ids) != num_prods"
-                + f"{len(product_ids)=} product count: {num_prods=}"
+                + " len(product_ids) != num_prods"
+                + f" {len(product_ids)=} product count: {num_prods=}"
             )
         self._log_progress_message("DONE sanity checks")
 
         self._log_progress_message("DONE: get_products_as_dataframe")
         return self._df_asset_metadata
 
+    def _save_intermediate_geodata(self, new_products):
+        self._log_progress_message("START: saving intermediate GeoData ...")
+        product_records = [
+            {k: v for k, v in self._product_to_dict(pr).items() if k != "geometry"} for pr in new_products
+        ]
+        product_geoms = [npr.geometry for npr in new_products]
+        gdf_products = gpd.GeoDataFrame(data=product_records, crs=EPSG_4326_LATLON, geometry=product_geoms)
+        gdf_products.index = gdf_products["id"]
+        gdf_products.sort_index()
+
+        if self._df_products is None:
+            self._df_products = gdf_products
+        else:
+            self._df_products = pd.concat([self._df_products, gdf_products])
+
+        self._log_progress_message("START: adding new assets to AssetMetadata GeoDataFrame ...")
+        assets_md = [self.create_asset_metadata(p) for p in new_products]
+        asset_records = [{k: v for k, v in md.to_dict().items() if k != "geometry_lat_lon"} for md in assets_md]
+        asset_geoms = [md.geometry_lat_lon for md in assets_md]
+        gdf_asset_md = gpd.GeoDataFrame(data=asset_records, crs=EPSG_4326_LATLON, geometry=asset_geoms)
+        gdf_asset_md.index = gdf_asset_md["asset_id"]
+        gdf_asset_md.sort_index()
+
+        if self._df_asset_metadata is None:
+            self._df_asset_metadata = gdf_asset_md
+        else:
+            self._df_asset_metadata = pd.concat([self._df_asset_metadata, gdf_asset_md])
+
+        self._log_progress_message("DONE: adding new assets to AssetMetadata GeoDataFrame")
+
+        self._save_dataframes()
+        self._log_progress_message("DONE: saving intermediate GeoData ...")
+
     def _product_to_dict(self, product: tcc.Product) -> dict[str, Any]:
         return {
             "id": product.id,
