From f3312ba9e62f78ed26b3b4a1ebbe51e508395d85 Mon Sep 17 00:00:00 2001
From: Elliott
Date: Fri, 8 Dec 2023 15:03:21 -0500
Subject: [PATCH] fix: better error messaging for no model outputs (#816)

---
 dataquality/__init__.py                             |  2 +-
 .../loggers/data_logger/text_classification.py      |  8 +++++++-
 dataquality/loggers/data_logger/text_multi_label.py | 13 +++++++++----
 3 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/dataquality/__init__.py b/dataquality/__init__.py
index 7e05cb71b..0a265fac4 100644
--- a/dataquality/__init__.py
+++ b/dataquality/__init__.py
@@ -31,7 +31,7 @@
 """
 
-__version__ = "1.3.8"
+__version__ = "1.3.9"
 
 import sys
 from typing import Any, List, Optional
 

diff --git a/dataquality/loggers/data_logger/text_classification.py b/dataquality/loggers/data_logger/text_classification.py
index e70f7eea0..23bdd492a 100644
--- a/dataquality/loggers/data_logger/text_classification.py
+++ b/dataquality/loggers/data_logger/text_classification.py
@@ -538,11 +538,17 @@ def validate_labels(cls) -> None:
             "See `dataquality.set_labels_for_run`"
         )
 
+        assert cls.logger_config.observed_num_labels, (
+            "There were no observed labels from the model output. Did you "
+            "log model outputs? Try calling dq.log_model_outputs() or using "
+            "`watch(trainer)` in your training loop."
+        )
+
         assert len(cls.logger_config.labels) == cls.logger_config.observed_num_labels, (
             f"You set your labels to be {cls.logger_config.labels} "
             f"({len(cls.logger_config.labels)} labels) but based on training, your "
             f"model is expecting {cls.logger_config.observed_num_labels} labels. "
-            f"Use dataquality.set_labels_for_run to update your config labels."
+            "Use dataquality.set_labels_for_run to update your config labels."
         )
 
         assert cls.logger_config.observed_labels.issubset(cls.logger_config.labels), (

diff --git a/dataquality/loggers/data_logger/text_multi_label.py b/dataquality/loggers/data_logger/text_multi_label.py
index 9fce26baf..2a89167f2 100644
--- a/dataquality/loggers/data_logger/text_multi_label.py
+++ b/dataquality/loggers/data_logger/text_multi_label.py
@@ -298,7 +298,7 @@ def validate_labels(cls) -> None:
                 f"task_{i}" for i in range(cls.logger_config.observed_num_tasks)
             ]
             warnings.warn(
-                f"No tasks were set for this run. Setting tasks to "
+                "No tasks were set for this run. Setting tasks to "
                 f"{cls.logger_config.tasks}"
             )
 
@@ -306,14 +306,19 @@
             f"You set your task names as {cls.logger_config.tasks} "
             f"({len(cls.logger_config.tasks)} tasks) but based on training, your model "
             f"has {cls.logger_config.observed_num_tasks} "
-            f"tasks. Use dataquality.set_tasks_for_run to update your config tasks."
+            "tasks. Use dataquality.set_tasks_for_run to update your config tasks."
         )
 
         assert len(cls.logger_config.labels) == cls.logger_config.observed_num_tasks, (
             f"You set your labels to be {cls.logger_config.labels} "
             f"({len(cls.logger_config.labels)} tasks) but based on training, your "
             f"model has {cls.logger_config.observed_num_tasks} tasks. "
-            f"Use dataquality.set_labels_for_run to update your config labels."
+            "Use dataquality.set_labels_for_run to update your config labels."
         )
+        assert cls.logger_config.observed_num_labels is not None, (
+            "There were no observed labels from the model output. Did you "
+            "log model outputs? Try calling dq.log_model_outputs() or using "
+            "`watch(trainer)` in your training loop."
+        )
         assert isinstance(cls.logger_config.observed_num_labels, list), (
             f"Is your task_type correct? The observed number of labels is "
@@ -325,7 +330,7 @@
             == cls.logger_config.observed_num_tasks
         ), (
             "Something went wrong with model output logging. Based on training, the "
-            f"observed number of labels per task is "
+            "observed number of labels per task is "
             f"{cls.logger_config.observed_num_labels} indicating "
             f"{len(cls.logger_config.observed_num_labels)} tasks, but the observed "
             f"number of tasks is only {cls.logger_config.observed_num_tasks}. Ensure "
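Note: the new assertions fire when observed_num_labels was never populated, and they point users
at dq.log_model_outputs() and `watch(trainer)`. For reference, a minimal sketch of the logging
flow those messages refer to, using the public dataquality API; the project/run names and the toy
arrays below are placeholders for illustration, not part of this patch:

import numpy as np
import dataquality as dq

# Placeholder project/run; any text classification run follows the same flow.
dq.init(
    task_type="text_classification",
    project_name="example_project",
    run_name="example_run",
)
dq.set_labels_for_run(["negative", "positive"])
dq.log_data_samples(
    texts=["great movie", "terrible plot"],
    labels=["positive", "negative"],
    split="training",
    ids=[0, 1],
)

# This is the call the new assertion asks for: logging model outputs is what
# sets observed_num_labels (here 2, matching the labels set above).
dq.log_model_outputs(
    embs=np.random.rand(2, 64),   # toy per-sample embeddings
    logits=np.random.rand(2, 2),  # one column per label
    ids=[0, 1],
    split="training",
    epoch=0,
)
dq.finish()

Skipping the dq.log_model_outputs() step (or the equivalent `watch(trainer)` integration) is what
previously produced an opaque length-mismatch assertion; with this patch it raises the explicit
"There were no observed labels from the model output" message instead.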