Catalog ingestion: verify id uniqueness in parallel for speed (#236)
* add a warning when fetching the matchfiles list; verify that IDs are unique when ingesting catalogs

* recreate the MongoDB connection between insert retries

* read files in parallel to find all IDs more quickly (sketched below)
Theodlz authored Aug 17, 2023
1 parent 282558a commit 62c9c3a
Showing 1 changed file with 61 additions and 32 deletions.
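The heart of the change is parallelizing the per-file ID read that happens before ingestion. Below is a minimal sketch of the idea, assuming parquet-only input for brevity; the actual diff further down also handles FITS and CSV, strips leading underscores from column names, and logs which IDs and files collide.

import multiprocessing

import pyarrow.parquet as pq


def get_file_ids(argument_list):
    # worker: return the list of IDs found in a single parquet file
    file, id_col = argument_list
    df = pq.read_table(file).to_pandas()
    return list(df[id_col])


def verify_ids(files, id_col, num_proc=4):
    # read every file in parallel, then fail if any ID appears more than once
    with multiprocessing.Pool(processes=num_proc) as pool:
        ids_per_file = dict(
            zip(files, pool.imap(get_file_ids, [(f, id_col) for f in files]))
        )
    ids = [i for file_ids in ids_per_file.values() for i in file_ids]
    if len(ids) != len(set(ids)):
        raise Exception(
            "Duplicate IDs found. Please make sure that all the IDs are "
            "unique across all files before ingesting"
        )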
93 changes: 61 additions & 32 deletions kowalski/ingesters/ingest_catalog.py
@@ -8,6 +8,7 @@
import random
import time
import traceback
from copy import deepcopy
from typing import Sequence

import fire
@@ -337,6 +338,7 @@ def process_file(argument_list: Sequence):
if dec_col not in names:
log(f"Provided DEC column {dec_col} not found")
return

batch = []

def convert_nparray_to_list(value):
@@ -404,7 +406,6 @@ def convert_nparray_to_list(value):
"coordinates": _radec_geojson,
}
batch.append(document)
total_good_documents += 1
except Exception as exception:
total_bad_documents += 1
log(str(exception))
@@ -426,12 +427,14 @@ def convert_nparray_to_list(value):
{"_id": {"$in": ids}}
)
if count == len(batch):
total_good_documents += len(batch)
break
n_retries += 1
time.sleep(6)
mongo.close()
mongo = get_mongo_client()
else:
total_good_documents += len(batch)
break

if n_retries == 10:
@@ -456,12 +459,14 @@ def convert_nparray_to_list(value):
{"_id": {"$in": ids}}
)
if count == len(batch):
total_good_documents += len(batch)
break
n_retries += 1
time.sleep(6)
mongo.close()
mongo = get_mongo_client()
else:
total_good_documents += len(batch)
break

if n_retries == 10:
@@ -501,36 +506,54 @@ def convert_nparray_to_list(value):
return total_good_documents, total_bad_documents


-def verify_ids(files, id_col, format):
-    ids_per_file = {}
-
-    for file in files:
-        if format == "fits":
-            with fits.open(file, cache=False) as hdulist:
-                nhdu = 1
-                names = hdulist[nhdu].columns.names
-                # first check if the id_col is in the names
-                if id_col not in names:
-                    raise Exception(
-                        f"Provided ID column {id_col} not found in file {file}"
-                    )
-                dataframe = pd.DataFrame(np.asarray(hdulist[nhdu].data), columns=names)
-                ids_per_file[file] = list(dataframe[id_col])
-        elif format == "csv":
-            dataframe = pd.read_csv(file)
-            if id_col not in dataframe.columns:
-                raise Exception(f"Provided ID column {id_col} not found in file {file}")
-            ids_per_file[file] = list(dataframe[id_col])
-        elif format == "parquet":
-            df = pq.read_table(file).to_pandas()
-            for name in list(df.columns):
-                if name.startswith("_"):
-                    df.rename(columns={name: name[1:]}, inplace=True)
-            if id_col not in df.columns:
-                raise Exception(f"Provided ID column {id_col} not found in file {file}")
-            ids_per_file[file] = list(df[id_col])
-        else:
-            raise Exception(f"Unknown format {format}")
+def get_file_ids(argument_list: Sequence):
+    file, id_col, format = argument_list
+
+    ids = []
+    if format == "fits":
+        with fits.open(file, cache=False) as hdulist:
+            nhdu = 1
+            names = hdulist[nhdu].columns.names
+            # first check if the id_col is in the names
+            if id_col not in names:
+                raise Exception(f"Provided ID column {id_col} not found in file {file}")
+            dataframe = pd.DataFrame(np.asarray(hdulist[nhdu].data), columns=names)
+            ids = list(dataframe[id_col])
+    elif format == "csv":
+        dataframe = pd.read_csv(file)
+        if id_col not in dataframe.columns:
+            raise Exception(f"Provided ID column {id_col} not found in file {file}")
+        ids = list(dataframe[id_col])
+    elif format == "parquet":
+        df = pq.read_table(file).to_pandas()
+        for name in list(df.columns):
+            if name.startswith("_"):
+                df.rename(columns={name: name[1:]}, inplace=True)
+        if id_col.startswith("_"):
+            id_col = id_col[1:]
+        if id_col not in df.columns:
+            raise Exception(f"Provided ID column {id_col} not found in file {file}")
+        ids = list(df[id_col])
+    else:
+        raise Exception(f"Unknown format {format}")
+
+    return ids
+
+
+def verify_ids(files: list, id_col: str, format: str, num_proc: int = 4):
+    ids_per_file = {}
+    files_copy = deepcopy(files)
+
+    with multiprocessing.Pool(processes=num_proc) as pool:
+        for result in tqdm(
+            pool.imap(
+                get_file_ids,
+                [(file, id_col, format) for file in files],
+            ),
+            total=len(files),
+        ):
+            file = files_copy.pop(0)
+            ids_per_file[file] = result

# now we have a list of all the ids in all the files
# we want to make sure that all the ids are unique
@@ -541,6 +564,8 @@ def verify_ids(files, id_col, format):
for file in files:
ids += ids_per_file[file]

log(f"in total, we found {len(set(ids))} unique IDs out of {len(ids)} IDs")

if len(ids) != len(set(ids)):
# we have duplicate ids
# we want to print out the file names concerned, and the ids concerned
@@ -558,7 +583,6 @@ def convert_nparray_to_list(value):
log(
f"{len(duplicate_ids)} duplicate IDs found. Please make sure that all the IDs are unique across all files before ingesting"
)
log(f"in total, we found {len(set(ids))} unique IDs out of {len(ids)} IDs")
raise Exception(
"Duplicate IDs found. Please make sure that all the IDs are unique across all files before ingesting"
)
@@ -577,6 +601,8 @@ def run(
max_docs: int = None,
rm: bool = False,
format: str = "fits",
verify_only: bool = False,
skip_verify: bool = False,
):
"""Pre-process and ingest catalog from fits files into Kowalski
:param path: path to fits file
@@ -611,8 +637,11 @@
for root, dirnames, filenames in os.walk(path):
files += [os.path.join(root, f) for f in filenames if f.endswith(format)]

if id_col is not None:
verify_ids(files, id_col, format)
if id_col is not None and not skip_verify:
verify_ids(files, id_col, format, num_proc=num_proc)

if verify_only:
return

input_list = [
(

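For reference, a hypothetical way to drive the two new flags. This is a sketch only: it assumes the module is importable as kowalski.ingesters.ingest_catalog, uses placeholder paths and column names, and omits other arguments that run() requires but which are not shown in this diff.

from kowalski.ingesters.ingest_catalog import run

# only check that IDs are unique across all files, without ingesting anything
run(path="/data/my_catalog", id_col="objectid", format="parquet", num_proc=8, verify_only=True)

# ingest while skipping the (potentially slow) uniqueness check
run(path="/data/my_catalog", id_col="objectid", format="parquet", num_proc=8, skip_verify=True)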