From 62c9c3abb3e36c3f2baaa9f2ba0e65a55193f89f Mon Sep 17 00:00:00 2001
From: Theophile du Laz
Date: Thu, 17 Aug 2023 18:00:36 +0200
Subject: [PATCH] Catalog ingestion: verify id uniqueness in parallel for speed (#236)

* add warning when fetching matchfiles list, verify that ids are unique when ingesting catalogs

* recreate mongodb connection

* read files in parallel to find all ids quicker
---
 kowalski/ingesters/ingest_catalog.py | 93 ++++++++++++++++++----------
 1 file changed, 61 insertions(+), 32 deletions(-)

diff --git a/kowalski/ingesters/ingest_catalog.py b/kowalski/ingesters/ingest_catalog.py
index b0e32411..e6576493 100644
--- a/kowalski/ingesters/ingest_catalog.py
+++ b/kowalski/ingesters/ingest_catalog.py
@@ -8,6 +8,7 @@
 import random
 import time
 import traceback
+from copy import deepcopy
 from typing import Sequence
 
 import fire
@@ -337,6 +338,7 @@ def process_file(argument_list: Sequence):
         if dec_col not in names:
             log(f"Provided DEC column {dec_col} not found")
             return
+
     batch = []
 
     def convert_nparray_to_list(value):
@@ -404,7 +406,6 @@ def convert_nparray_to_list(value):
                     "coordinates": _radec_geojson,
                 }
                 batch.append(document)
-                total_good_documents += 1
             except Exception as exception:
                 total_bad_documents += 1
                 log(str(exception))
@@ -426,12 +427,14 @@ def convert_nparray_to_list(value):
                         {"_id": {"$in": ids}}
                     )
                     if count == len(batch):
+                        total_good_documents += len(batch)
                         break
                     n_retries += 1
                     time.sleep(6)
                     mongo.close()
                     mongo = get_mongo_client()
                 else:
+                    total_good_documents += len(batch)
                     break
 
             if n_retries == 10:
@@ -456,12 +459,14 @@ def convert_nparray_to_list(value):
                         {"_id": {"$in": ids}}
                     )
                     if count == len(batch):
+                        total_good_documents += len(batch)
                         break
                     n_retries += 1
                     time.sleep(6)
                     mongo.close()
                     mongo = get_mongo_client()
                 else:
+                    total_good_documents += len(batch)
                     break
 
             if n_retries == 10:
@@ -501,36 +506,54 @@ def convert_nparray_to_list(value):
     return total_good_documents, total_bad_documents
 
 
-def verify_ids(files, id_col, format):
-    ids_per_file = {}
+def get_file_ids(argument_list: Sequence):
+    file, id_col, format = argument_list
 
-    for file in files:
-        if format == "fits":
-            with fits.open(file, cache=False) as hdulist:
-                nhdu = 1
-                names = hdulist[nhdu].columns.names
-                # first check if the id_col is in the names
-                if id_col not in names:
-                    raise Exception(
-                        f"Provided ID column {id_col} not found in file {file}"
-                    )
-                dataframe = pd.DataFrame(np.asarray(hdulist[nhdu].data), columns=names)
-                ids_per_file[file] = list(dataframe[id_col])
-        elif format == "csv":
-            dataframe = pd.read_csv(file)
-            if id_col not in dataframe.columns:
-                raise Exception(f"Provided ID column {id_col} not found in file {file}")
-            ids_per_file[file] = list(dataframe[id_col])
-        elif format == "parquet":
-            df = pq.read_table(file).to_pandas()
-            for name in list(df.columns):
-                if name.startswith("_"):
-                    df.rename(columns={name: name[1:]}, inplace=True)
-            if id_col not in df.columns:
+    ids = []
+    if format == "fits":
+        with fits.open(file, cache=False) as hdulist:
+            nhdu = 1
+            names = hdulist[nhdu].columns.names
+            # first check if the id_col is in the names
+            if id_col not in names:
                 raise Exception(f"Provided ID column {id_col} not found in file {file}")
-            ids_per_file[file] = list(df[id_col])
-        else:
-            raise Exception(f"Unknown format {format}")
+            dataframe = pd.DataFrame(np.asarray(hdulist[nhdu].data), columns=names)
+            ids = list(dataframe[id_col])
+    elif format == "csv":
+        dataframe = pd.read_csv(file)
+        if id_col not in dataframe.columns:
+            raise Exception(f"Provided ID column {id_col} not found in file {file}")
in file {file}") + ids = list(dataframe[id_col]) + elif format == "parquet": + df = pq.read_table(file).to_pandas() + for name in list(df.columns): + if name.startswith("_"): + df.rename(columns={name: name[1:]}, inplace=True) + if id_col.startswith("_"): + id_col = id_col[1:] + if id_col not in df.columns: + raise Exception(f"Provided ID column {id_col} not found in file {file}") + ids = list(df[id_col]) + else: + raise Exception(f"Unknown format {format}") + + return ids + + +def verify_ids(files: list, id_col: str, format: str, num_proc: int = 4): + ids_per_file = {} + files_copy = deepcopy(files) + + with multiprocessing.Pool(processes=num_proc) as pool: + for result in tqdm( + pool.imap( + get_file_ids, + [(file, id_col, format) for file in files], + ), + total=len(files), + ): + file = files_copy.pop(0) + ids_per_file[file] = result # now we have a list of all the ids in all the files # we want to make sure that all the ids are unique @@ -541,6 +564,8 @@ def verify_ids(files, id_col, format): for file in files: ids += ids_per_file[file] + log(f"in total, we found {len(set(ids))} unique IDs out of {len(ids)} IDs") + if len(ids) != len(set(ids)): # we have duplicate ids # we want to print out the file names concerned, and the ids concerned @@ -558,7 +583,6 @@ def verify_ids(files, id_col, format): log( f"{len(duplicate_ids)} duplicate IDs found. Please make sure that all the IDs are unique across all files before ingesting" ) - log(f"in total, we found {len(set(ids))} unique IDs out of {len(ids)} IDs") raise Exception( "Duplicate IDs found. Please make sure that all the IDs are unique across all files before ingesting" ) @@ -577,6 +601,8 @@ def run( max_docs: int = None, rm: bool = False, format: str = "fits", + verify_only: bool = False, + skip_verify: bool = False, ): """Pre-process and ingest catalog from fits files into Kowalski :param path: path to fits file @@ -611,8 +637,11 @@ def run( for root, dirnames, filenames in os.walk(path): files += [os.path.join(root, f) for f in filenames if f.endswith(format)] - if id_col is not None: - verify_ids(files, id_col, format) + if id_col is not None and not skip_verify: + verify_ids(files, id_col, format, num_proc=num_proc) + + if verify_only: + return input_list = [ (