diff --git a/spark/jobs/completion_job.py b/spark/jobs/completion_job.py
index 93ab9a91..36850e94 100644
--- a/spark/jobs/completion_job.py
+++ b/spark/jobs/completion_job.py
@@ -17,9 +17,9 @@ def compute_metrics(df: DataFrame) -> dict:
     complete_record = {}
     completion_service = CompletionMetrics()
     model_quality = completion_service.extract_metrics(df)
-    complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump(serialize_as_any=True)).decode(
-        "utf-8"
-    )
+    complete_record["MODEL_QUALITY"] = orjson.dumps(
+        model_quality.model_dump(serialize_as_any=True)
+    ).decode("utf-8")
     return complete_record
 
 
diff --git a/spark/jobs/current_job.py b/spark/jobs/current_job.py
index ff0140da..4e7eb1ba 100644
--- a/spark/jobs/current_job.py
+++ b/spark/jobs/current_job.py
@@ -202,6 +202,7 @@ def main(
             current_uuid,
             reference_dataset_path,
             metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
diff --git a/spark/jobs/reference_job.py b/spark/jobs/reference_job.py
index 4750ab3b..8e7017eb 100644
--- a/spark/jobs/reference_job.py
+++ b/spark/jobs/reference_job.py
@@ -146,6 +146,7 @@ def main(
             reference_dataset_path,
             reference_uuid,
             metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
diff --git a/spark/tests/completion_metrics_test.py b/spark/tests/completion_metrics_test.py
index d4c81c43..19dcb446 100644
--- a/spark/tests/completion_metrics_test.py
+++ b/spark/tests/completion_metrics_test.py
@@ -31,7 +31,9 @@ def test_compute_prob(spark_fixture, input_file):
 
 def test_extract_metrics(spark_fixture, input_file):
     completion_metrics_service = CompletionMetrics()
-    completion_metrics_model: CompletionMetricsModel = completion_metrics_service.extract_metrics(input_file)
+    completion_metrics_model: CompletionMetricsModel = (
+        completion_metrics_service.extract_metrics(input_file)
+    )
     assert len(completion_metrics_model.tokens) > 0
     assert len(completion_metrics_model.mean_per_phrase) > 0
     assert len(completion_metrics_model.mean_per_file) > 0
@@ -40,6 +42,4 @@ def test_extract_metrics(spark_fixture, input_file):
 def test_compute_metrics(spark_fixture, input_file):
     complete_record = compute_metrics(input_file)
     model_quality = complete_record.get("MODEL_QUALITY")
-    assert model_quality == orjson.dumps(completion_metric_results).decode(
-        "utf-8"
-    )
+    assert model_quality == orjson.dumps(completion_metric_results).decode("utf-8")