
Commit

fix: handle dataset table name argument
dtria91 committed Dec 13, 2024
1 parent 5cb802a commit 8cf0812
Showing 4 changed files with 9 additions and 7 deletions.
6 changes: 3 additions & 3 deletions spark/jobs/completion_job.py
@@ -17,9 +17,9 @@ def compute_metrics(df: DataFrame) -> dict:
     complete_record = {}
     completion_service = CompletionMetrics()
     model_quality = completion_service.extract_metrics(df)
-    complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump(serialize_as_any=True)).decode(
-        "utf-8"
-    )
+    complete_record["MODEL_QUALITY"] = orjson.dumps(
+        model_quality.model_dump(serialize_as_any=True)
+    ).decode("utf-8")
     return complete_record


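For reference, the reformatted lines serialize a pydantic model to a JSON string with orjson. A minimal sketch of that pattern, assuming pydantic v2 and orjson are installed; the ExampleMetrics model is hypothetical, and only the dumps / model_dump(serialize_as_any=True) / decode chain mirrors the diff:

    import orjson
    from pydantic import BaseModel


    class ExampleMetrics(BaseModel):
        # Hypothetical stand-in for the model returned by extract_metrics()
        tokens: int = 0
        mean_per_file: float = 0.0


    model_quality = ExampleMetrics(tokens=3, mean_per_file=0.5)
    # model_dump() yields a plain dict; orjson.dumps() returns bytes, hence decode()
    payload = orjson.dumps(
        model_quality.model_dump(serialize_as_any=True)
    ).decode("utf-8")
    print(payload)  # '{"tokens":3,"mean_per_file":0.5}'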
1 change: 1 addition & 0 deletions spark/jobs/current_job.py
@@ -202,6 +202,7 @@ def main(
             current_uuid,
             reference_dataset_path,
             metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
1 change: 1 addition & 0 deletions spark/jobs/reference_job.py
@@ -146,6 +146,7 @@ def main(
             reference_dataset_path,
             reference_uuid,
             metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
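Both job entry points get the same fix: dataset_table_name is now forwarded along with the other arguments instead of being dropped at the call site. A minimal sketch of that pattern under assumptions — the worker name run_job and the surrounding signature are not shown in the diff and are purely hypothetical; only the added dataset_table_name argument comes from the commit:

    import logging


    def run_job(reference_dataset_path, reference_uuid, metrics_table_name, dataset_table_name):
        # Hypothetical worker; in the real jobs this would read the dataset
        # and write metrics to the given tables.
        ...


    def main(reference_dataset_path, reference_uuid, metrics_table_name, dataset_table_name):
        try:
            run_job(
                reference_dataset_path,
                reference_uuid,
                metrics_table_name,
                dataset_table_name,  # argument added by this commit
            )
        except Exception as e:
            logging.exception(e)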
8 changes: 4 additions & 4 deletions spark/tests/completion_metrics_test.py
@@ -31,7 +31,9 @@ def test_compute_prob(spark_fixture, input_file):
 
 def test_extract_metrics(spark_fixture, input_file):
     completion_metrics_service = CompletionMetrics()
-    completion_metrics_model: CompletionMetricsModel = completion_metrics_service.extract_metrics(input_file)
+    completion_metrics_model: CompletionMetricsModel = (
+        completion_metrics_service.extract_metrics(input_file)
+    )
     assert len(completion_metrics_model.tokens) > 0
     assert len(completion_metrics_model.mean_per_phrase) > 0
     assert len(completion_metrics_model.mean_per_file) > 0
@@ -40,6 +42,4 @@ def test_extract_metrics(spark_fixture, input_file):
 def test_compute_metrics(spark_fixture, input_file):
     complete_record = compute_metrics(input_file)
     model_quality = complete_record.get("MODEL_QUALITY")
-    assert model_quality == orjson.dumps(completion_metric_results).decode(
-        "utf-8"
-    )
+    assert model_quality == orjson.dumps(completion_metric_results).decode("utf-8")
