fix: better error messaging for no model outputs (#816)
Elliott authored Dec 8, 2023
1 parent aa1c312 commit f3312ba
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
 """
 
 
-__version__ = "1.3.8"
+__version__ = "1.3.9"
 
 import sys
 from typing import Any, List, Optional
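
A quick way to confirm a local environment has picked up this release is to check the version attribute bumped above (a trivial sketch; only `__version__` itself is guaranteed by this diff):

import dataquality

print(dataquality.__version__)  # "1.3.9" or later includes this fix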
8 changes: 7 additions & 1 deletion dataquality/loggers/data_logger/text_classification.py
@@ -538,11 +538,17 @@ def validate_labels(cls) -> None:
             "See `dataquality.set_labels_for_run`"
         )
 
+        assert cls.logger_config.observed_num_labels, (
+            "There were no observed labels from the model output. Did you "
+            "log model outputs? Try calling dq.log_model_outputs() or using "
+            "`watch(trainer)` in your training loop."
+        )
+
         assert len(cls.logger_config.labels) == cls.logger_config.observed_num_labels, (
             f"You set your labels to be {cls.logger_config.labels} "
             f"({len(cls.logger_config.labels)} labels) but based on training, your "
             f"model is expecting {cls.logger_config.observed_num_labels} labels. "
-            f"Use dataquality.set_labels_for_run to update your config labels."
+            "Use dataquality.set_labels_for_run to update your config labels."
         )
 
         assert cls.logger_config.observed_labels.issubset(cls.logger_config.labels), (
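
For context, here is a minimal sketch of the failure mode this new assert catches, and of the call that satisfies it. The `dq.log_model_outputs` keyword names (embs, logits, ids, split, epoch) follow the library's documented logging API, but treat the surrounding setup, project names, and toy tensors as illustrative rather than verbatim:

import numpy as np
import dataquality as dq

dq.init(task_type="text_classification")         # hypothetical project/run
dq.set_labels_for_run(["negative", "positive"])

# If training runs without ever logging model outputs, observed_num_labels
# stays empty and validation now fails fast with the message added above:
#   "There were no observed labels from the model output. Did you log model
#    outputs? Try calling dq.log_model_outputs() or using `watch(trainer)`
#    in your training loop."

# Logging outputs once per batch satisfies the check:
dq.log_model_outputs(
    embs=np.random.rand(4, 768),   # (batch, emb_dim) hidden states
    logits=np.random.rand(4, 2),   # (batch, num_labels)
    ids=[0, 1, 2, 3],              # ids previously logged with the input data
    split="training",
    epoch=0,
)
dq.finish()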
13 changes: 9 additions & 4 deletions dataquality/loggers/data_logger/text_multi_label.py
@@ -298,22 +298,27 @@ def validate_labels(cls) -> None:
                 f"task_{i}" for i in range(cls.logger_config.observed_num_tasks)
             ]
             warnings.warn(
-                f"No tasks were set for this run. Setting tasks to "
+                "No tasks were set for this run. Setting tasks to "
                 f"{cls.logger_config.tasks}"
             )
 
         assert len(cls.logger_config.tasks) == cls.logger_config.observed_num_tasks, (
             f"You set your task names as {cls.logger_config.tasks} "
             f"({len(cls.logger_config.tasks)} tasks but based on training, your model "
             f"has {cls.logger_config.observed_num_tasks} "
-            f"tasks. Use dataquality.set_tasks_for_run to update your config tasks."
+            "tasks. Use dataquality.set_tasks_for_run to update your config tasks."
         )
 
         assert len(cls.logger_config.labels) == cls.logger_config.observed_num_tasks, (
             f"You set your labels to be {cls.logger_config.labels} "
             f"({len(cls.logger_config.labels)} tasks) but based on training, your "
             f"model has {cls.logger_config.observed_num_tasks} tasks. "
-            f"Use dataquality.set_labels_for_run to update your config labels."
+            "Use dataquality.set_labels_for_run to update your config labels."
         )
+        assert cls.logger_config.observed_num_labels is not None, (
+            "There were no observed labels from the model output. Did you "
+            "log model outputs? Try calling dq.log_model_outputs() or using "
+            "`watch(trainer)` in your training loop."
+        )
         assert isinstance(cls.logger_config.observed_num_labels, list), (
             f"Is your task_type correct? The observed number of labels is "
@@ -325,7 +330,7 @@
             == cls.logger_config.observed_num_tasks
         ), (
             "Something went wrong with model output logging. Based on training, the "
-            f"observed number of labels per task is "
+            "observed number of labels per task is "
             f"{cls.logger_config.observed_num_labels} indicating "
             f"{len(cls.logger_config.observed_num_labels)} tasks, but the observed "
             f"number of tasks is only {cls.logger_config.observed_num_tasks}. Ensure "
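
Read together, the multi-label asserts encode a simple contract between the run config and what the model actually emitted. A standalone restatement of those invariants (an illustration only, not the library's code):

from typing import List, Optional

def check_multi_label_run(
    tasks: List[str],
    labels: List[List[str]],
    observed_num_tasks: int,
    observed_num_labels: Optional[List[int]],
) -> None:
    # One task name and one label list per task the model was trained on.
    assert len(tasks) == observed_num_tasks
    assert len(labels) == observed_num_tasks
    # New in this commit: a clear error when no outputs were logged at all,
    # instead of a confusing isinstance failure further down.
    assert observed_num_labels is not None, (
        "No observed labels: call dq.log_model_outputs() or watch(trainer)"
    )
    assert isinstance(observed_num_labels, list)
    # The model's per-task label counts must cover the same number of tasks.
    assert len(observed_num_labels) == observed_num_tasks

# Two tasks, each with its own label set:
check_multi_label_run(
    tasks=["toxicity", "sentiment"],
    labels=[["clean", "toxic"], ["negative", "positive"]],
    observed_num_tasks=2,
    observed_num_labels=[2, 2],
)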
