Commit

add Verifier class
troyraen committed Sep 15, 2024
1 parent dbc259e commit a6d24e4
Showing 1 changed file with 172 additions and 74 deletions.
246 changes: 172 additions & 74 deletions src/hipscat_import/verification/run_verification.py
"""Run pass/fail tests and generate verification report of existing hipscat table."""

import collections
import dataclasses
import datetime
import re
from pathlib import Path

import hipscat.io.validation
import pandas as pd
import pyarrow.dataset

from hipscat_import.verification.arguments import VerificationArguments


def run(args):
    if not isinstance(args, VerificationArguments):
        raise TypeError("args must be type VerificationArguments")

    verifier = Verifier.from_args(args)
    verifier.test_is_valid_catalog()
    verifier.test_schemas()
    verifier.test_num_rows()
    verifier.write_results()

    # args.field_distribution_cols  # [TODO]
    verifier.distributions

    return verifier


Result = collections.namedtuple("Result", ["datetime", "passed", "test", "description"])
"""Verification test result."""


def now():
    """Return the current UTC time as a formatted string for the results log."""
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y/%m/%d %H:%M:%S %Z")


@dataclasses.dataclass
class Verifier:
    """Runs verification tests against a hipscat catalog and collects the results."""

    args: VerificationArguments = dataclasses.field()
    """Arguments to use during verification."""
    files_ds: pyarrow.dataset.Dataset = dataclasses.field()
    """Pyarrow dataset, loaded from the actual files on disk."""
    metadata_ds: pyarrow.dataset.Dataset = dataclasses.field()
    """Pyarrow dataset, loaded from the _metadata file."""
    common_ds: pyarrow.dataset.Dataset = dataclasses.field()
    """Pyarrow dataset, loaded from the _common_metadata file."""
    truth_schema: pyarrow.Schema | None = dataclasses.field(default=None)
    """Pyarrow schema to be used as truth. This will be loaded from args.use_schema_file
    if provided; else the catalog's _common_metadata file will be used."""
    results: list[Result] = dataclasses.field(default_factory=list)
    """List of results, one for each test that has been done."""
    _distributions: pd.DataFrame | None = dataclasses.field(default=None, init=False)
    """Cached column min/max statistics, computed lazily by the distributions property."""

    @classmethod
    def from_args(cls, args) -> "Verifier":
        # load a dataset from the actual files on disk
        files_ds = pyarrow.dataset.dataset(
            args.input_catalog_path,
            ignore_prefixes=[
                ".",
                "_",
                "catalog_info.json",
                "partition_info.csv",
                "point_map.fits",
                "provenance_info.json",
            ],
        )

        # load a dataset from the _metadata file
        metadata_ds = pyarrow.dataset.parquet_dataset(f"{args.input_catalog_path}/_metadata")

        # load a dataset from the _common_metadata file
        common_ds = pyarrow.dataset.parquet_dataset(f"{args.input_catalog_path}/_common_metadata")

        # load the input schema if provided, else use the _common_metadata schema
        if args.use_schema_file is not None:
            truth_schema = pyarrow.dataset.parquet_dataset(args.use_schema_file).schema
        else:
            truth_schema = common_ds.schema

        return cls(
            args=args,
            files_ds=files_ds,
            metadata_ds=metadata_ds,
            common_ds=common_ds,
            truth_schema=truth_schema,
        )

    @property
    def distributions(self) -> pd.DataFrame:
        """Minimum and maximum value of each column, aggregated from the parquet row-group statistics."""
        if self._distributions is None:
            rowgrp_stats = [rg.statistics for frag in self.files_ds.get_fragments() for rg in frag.row_groups]
            dist = pd.json_normalize(rowgrp_stats)

            min_ = dist[[f"{c}.min" for c in self.truth_schema.names]].min()
            min_ = min_.rename(index={name: name.removesuffix(".min") for name in min_.index})

            max_ = dist[[f"{c}.max" for c in self.truth_schema.names]].max()
            max_ = max_.rename(index={name: name.removesuffix(".max") for name in max_.index})

            self._distributions = pd.DataFrame({"min": min_, "max": max_})
        return self._distributions

    @property
    def resultsdf(self) -> pd.DataFrame:
        """Results of the tests run so far, as a dataframe with one row per test."""
        return pd.DataFrame(self.results)

    def test_is_valid_catalog(self) -> bool:
        is_valid = hipscat.io.validation.is_valid_catalog(self.args.input_catalog_path, strict=True)
        # [FIXME] How to get the hipscat version?
        description = "Test if this is a valid HiPSCat catalog using hipscat version <VERSION>."
        self._append_result(test_name="is valid catalog", description=description, passed=is_valid)
        return is_valid

    def test_schemas(self, check_file_metadata: bool = True) -> None:
        test_name = "schema"
        _inex = "including file metadata" if check_file_metadata else "excluding file metadata"

        if self.args.use_schema_file is None:
            truth_src = "_common_metadata"
        else:
            truth_src = "input"
            # an input schema was provided as truth, so we need to test _common_metadata against it
            passed = self.common_ds.schema.equals(self.truth_schema, check_metadata=check_file_metadata)
            description = f"Test that _common_metadata schema equals {truth_src} schema, {_inex}."
            self._append_result(passed=passed, test_name=test_name, description=description)

        # test schema in file footers
        frags_passed = [
            frag.physical_schema.equals(self.truth_schema, check_metadata=check_file_metadata)
            for frag in self.files_ds.get_fragments()
        ]
description = f"Test that schema in file footers equals {truth_src} schema, {_inex}."
self._append_result(passed=all(frags_passed), test_name=test_name, description=description)

# test _metadata schema
passed = self.common_ds.schema.equals(self.truth_schema, check_metadata=check_file_metadata)
description = f"Test that _metadata schema equals {truth_src} schema, {_inex}."
self._append_result(passed=passed, test_name=test_name, description=description)

    def test_num_rows(self) -> None:
        test_name = "num rows"
        # get the number of rows in each file, indexed by partition. we treat this as truth.
        files_df = self._load_nrows(self.files_ds)

        # check _metadata
        metadata_df = self._load_nrows(self.metadata_ds)
        description = "Test that number of rows in each file matches _metadata file."
        self._append_result(passed=metadata_df.equals(files_df), test_name=test_name, description=description)

        # check total number of rows
        if self.args.expected_total_rows is not None:
            passed = self.args.expected_total_rows == files_df.num_rows.sum()
            description = "Test that total number of rows matches user-supplied expectation."
            self._append_result(passed=passed, test_name=test_name, description=description)

    def _load_nrows(self, dataset: pyarrow.dataset.Dataset) -> pd.DataFrame:
        partition_cols = ["Norder", "Dir", "Npix"]
        nrows_df = pd.DataFrame(
            columns=partition_cols + ["num_rows"],
            data=[
                (
                    int(re.search(r"Norder=(\d+)", frag.path).group(1)),
                    int(re.search(r"Dir=(\d+)", frag.path).group(1)),
                    int(re.search(r"Npix=(\d+)", frag.path).group(1)),
                    frag.metadata.num_rows,
                )
                for frag in dataset.get_fragments()
            ],
        )
        nrows_df = nrows_df.set_index(partition_cols).sort_index()
        return nrows_df

    def _append_result(self, *, test_name: str, description: str, passed: bool):
        self.results.append(Result(datetime=now(), passed=passed, test=test_name, description=description))

    def write_results(self, *, mode: str = "a") -> None:
        fout = Path(self.args.output_path) / "verifier_results.csv"
        header = not (fout.is_file() and mode == "a")
        self.resultsdf.to_csv(fout, index=False, mode=mode, header=header)
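
A minimal usage sketch of the new workflow. It assumes VerificationArguments can be constructed with the input_catalog_path, output_path, and expected_total_rows fields referenced in the code above; the paths and row count are placeholders.

from hipscat_import.verification.arguments import VerificationArguments
from hipscat_import.verification.run_verification import run

# hypothetical arguments; field names are taken from the attributes used above
args = VerificationArguments(
    input_catalog_path="path/to/hipscat/catalog",
    output_path="path/to/report_dir",
    expected_total_rows=1_000_000,
)

verifier = run(args)           # runs all tests and appends to verifier_results.csv
print(verifier.resultsdf)      # one row per test: datetime, passed, test, description
print(verifier.distributions)  # per-column min/max aggregated from row-group statistics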
