Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bylabelcount metric type #1484

Merged
merged 5 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/evidently/future/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@

register_type_alias(Metric, "evidently.future.metrics.classification.ClassificationQualityBase", "evidently:metric_v2:ClassificationQualityBase")
register_type_alias(Metric, "evidently.future.metrics.classification.DummyClassificationQuality", "evidently:metric_v2:DummyClassificationQuality")

register_type_alias(BoundTest, "evidently.future.metric_types.ByLabelCountBoundTest", "evidently:bound_test:ByLabelCountBoundTest")
register_type_alias(Metric, "evidently.future.metric_types.ByLabelCountMetric", "evidently:metric_v2:ByLabelCountMetric")
16 changes: 16 additions & 0 deletions src/evidently/future/backport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from evidently.future.datasets import DataDefinition
from evidently.future.datasets import Dataset
from evidently.future.metric_types import BoundTest
from evidently.future.metric_types import ByLabelCountValue
from evidently.future.metric_types import ByLabelValue
from evidently.future.metric_types import CountValue
from evidently.future.metric_types import MeanStdValue
Expand Down Expand Up @@ -89,6 +90,15 @@ class Config:
values: Dict[Label, Union[float, int, bool, str]]


class ByLabelCountValueV1(MetricResultV2Adapter):
class Config:
type_alias = "evidently:metric_result:ByLabelCountValueV1"
field_tags = {"values": {IncludeTags.Render}}

counts: Dict[Label, int]
shares: Dict[Label, float]


class CountValueV1(MetricResultV2Adapter):
class Config:
type_alias = "evidently:metric_result:CountValueV1"
Expand Down Expand Up @@ -123,6 +133,12 @@ def metric_result_v2_to_v1(metric_result: MetricResultV2, ignore_widget: bool =
widget=_create_metric_result_widget(metric_result, ignore_widget),
values=metric_result.values,
)
if isinstance(metric_result, ByLabelCountValue):
return ByLabelCountValueV1(
widget=_create_metric_result_widget(metric_result, ignore_widget),
counts=metric_result.counts,
shares=metric_result.shares,
)
if isinstance(metric_result, CountValue):
return CountValueV1(
widget=_create_metric_result_widget(metric_result, ignore_widget),
Expand Down
153 changes: 145 additions & 8 deletions src/evidently/future/metric_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Dict
from typing import Generic
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Sequence
Expand Down Expand Up @@ -82,6 +83,13 @@ def display_name(self) -> str:
return self._display_name

def to_dict(self):
return {
"id": self._metric.id,
"metric_id": self.explicit_metric_id(),
"value": self.dict(),
}

def explicit_metric_id(self):
config = self.metric.to_metric().dict()
config_items = []
type = None
Expand All @@ -99,11 +107,10 @@ def to_dict(self):
continue
else:
config_items.append(f"{field}={str(value)}")
return {
"id": self._metric.id,
"metric_id": f"{type}({','.join(config_items)})",
"value": self.dict(),
}
if self._metric_value_location is not None:
for key, value in self._metric_value_location.params().items():
config_items.append(f"{key}={value}")
return f"{type}({','.join(config_items)})"

@abc.abstractmethod
def dict(self) -> object:
Expand Down Expand Up @@ -225,6 +232,32 @@ def dict(self) -> object:
return self.values


@dataclasses.dataclass
class ByLabelCountValue(MetricResult):
counts: Dict[Label, int]
shares: Dict[Label, float]

def labels(self) -> List[Label]:
return list(self.counts.keys())

def get_label_result(self, label: Label) -> Tuple[SingleValue, SingleValue]:
count = SingleValue(self.counts[label])
share = SingleValue(self.shares[label])
metric = self.metric
count._metric = metric
share._metric = metric
if not isinstance(metric, ByLabelCountCalculation):
raise ValueError(f"Metric {type(metric)} isn't ByLabelCountCalculation")
count.set_display_name(metric.count_label_display_name(label))
share.set_display_name(metric.share_label_display_name(label))
count._metric_value_location = ByLabelCountValueLocation(metric.to_metric(), label, "count")
share._metric_value_location = ByLabelCountValueLocation(metric.to_metric(), label, "share")
return count, share

def dict(self) -> object:
return {"counts": self.counts, "shares": self.shares}


@dataclasses.dataclass
class CountValue(MetricResult):
count: int
Expand Down Expand Up @@ -304,6 +337,9 @@ class DatasetType(enum.Enum):
class MetricValueLocation:
metric: "Metric"

def params(self) -> Dict[str, str]:
raise NotImplementedError

def value(self, context: "Context", dataset_type: DatasetType) -> SingleValue:
value = self._metric_value_by_dataset(context, dataset_type)
return self.extract_value(value)
Expand Down Expand Up @@ -331,6 +367,9 @@ def extract_value(self, value: MetricResult) -> SingleValue:
)
return value

def params(self) -> Dict[str, str]:
return {}


@dataclasses.dataclass
class ByLabelValueLocation(MetricValueLocation):
Expand All @@ -344,6 +383,30 @@ def extract_value(self, value: MetricResult) -> SingleValue:
)
return value.get_label_result(self.label)

def params(self) -> Dict[str, str]:
return {"label": str(self.label)}


ByLabelCountSlot = Union[Literal["count"], Literal["share"]]


@dataclasses.dataclass
class ByLabelCountValueLocation(MetricValueLocation):
label: Label
slot: ByLabelCountSlot

def extract_value(self, value: MetricResult) -> SingleValue:
if not isinstance(value, ByLabelCountValue):
raise ValueError(
f"Unexpected type of metric result for metric[{str(value.metric)}]:"
f" expected: {ByLabelCountValue.__name__}, actual: {type(value).__name__}"
)
result = value.get_label_result(self.label)
return result[0] if self.slot == "count" else result[1]

def params(self) -> Dict[str, str]:
return {"label": str(self.label), "value_type": str(self.slot)}


@dataclasses.dataclass
class CountValueLocation(MetricValueLocation):
Expand All @@ -357,6 +420,9 @@ def extract_value(self, value: MetricResult) -> SingleValue:
)
return value.get_count() if self.is_count else value.get_share()

def params(self) -> Dict[str, str]:
return {"value_type": "count" if self.is_count else "share"}


@dataclasses.dataclass
class MeanStdValueLocation(MetricValueLocation):
Expand All @@ -370,6 +436,9 @@ def extract_value(self, value: MetricResult) -> SingleValue:
)
return value.get_mean() if self.is_mean else value.get_std()

def params(self) -> Dict[str, str]:
return {"value_type": "mean" if self.is_mean else "std"}


class MetricTestProto(Protocol[TResult]):
def __call__(self, context: "Context", metric: "MetricCalculationBase", value: TResult) -> MetricTestResult: ...
Expand Down Expand Up @@ -426,6 +495,16 @@ def get_default_render_ref(title: str, result: MetricResult, ref_result: MetricR
data=[(k, f"{v:0.3f}", f"{ref_result.values[k]}") for k, v in result.values.items()],
)
]
if isinstance(result, ByLabelCountValue):
assert isinstance(ref_result, ByLabelCountValue)
return [
table_data(
title=title,
size=WidgetSize.FULL,
column_names=["Label", "Current value", "Reference value"],
data=[(k, f"{v:0.3f}", f"{ref_result.counts[k]}") for k, v in result.counts.items()],
)
]
if isinstance(result, CountValue):
assert isinstance(ref_result, CountValue)
return [
Expand Down Expand Up @@ -486,6 +565,14 @@ def get_default_render(title: str, result: TResult) -> List[BaseWidgetInfo]:
data=[(k, f"{v:0.3f}") for k, v in result.values.items()],
)
]
if isinstance(result, ByLabelCountValue):
return [
table_data(
title=title,
column_names=["Label", "Value"],
data=[(k, f"{v:0.3f}") for k, v in result.counts.items()],
)
]
if isinstance(result, CountValue):
return [
counter(
Expand Down Expand Up @@ -599,9 +686,7 @@ def run(self, context: "Context", metric: "MetricCalculationBase", value: Metric
result: MetricTestResult = self.to_test()(context, metric, value)
if result.status == TestStatus.FAIL and not self.is_critical:
result.status = TestStatus.WARNING
metric_conf = metric.to_metric()
column = f" for {metric_conf.column}" if hasattr(metric_conf, "column") else ""
result.description = f"{metric_conf.__class__.__name__}{column}: {result.description}"
result.description = f"{value.explicit_metric_id()}: {result.description}"
return result

def bind_single(self, fingerprint: Fingerprint) -> "BoundTest":
Expand All @@ -613,6 +698,9 @@ def bind_count(self, fingerprint: Fingerprint, is_count: bool) -> "BoundTest":
def bind_by_label(self, fingerprint: Fingerprint, label: Label):
return ByLabelBoundTest(test=self, metric_fingerprint=fingerprint, label=label)

def bind_by_label_count(self, fingerprint: Fingerprint, label: Label, slot: ByLabelCountSlot):
return ByLabelCountBoundTest(test=self, metric_fingerprint=fingerprint, label=label, slot=slot)

def bind_mean_std(self, fingerprint: Fingerprint, is_mean: bool = True):
return MeanStdBoundTest(test=self, metric_fingerprint=fingerprint, is_mean=is_mean)

Expand Down Expand Up @@ -793,6 +881,55 @@ def label_display_name(self, label: Label) -> str:
return self.display_name() + f" for label {label}"


class ByLabelCountBoundTest(BoundTest[ByLabelCountValue]):
label: Label
slot: ByLabelCountSlot

def run_test(
self,
context: "Context",
calculation: MetricCalculationBase,
metric_result: ByLabelCountValue,
) -> MetricTestResult:
value = metric_result.get_label_result(self.label)
return self.test.run(context, calculation, value[0] if self.slot == "count" else value[1])


class ByLabelCountMetric(Metric["ByLabelCountCalculation"]):
tests: Optional[Dict[Label, List[MetricTest]]] = None
share_tests: Optional[Dict[Label, List[MetricTest]]] = None

def get_bound_tests(self, context: "Context") -> List[BoundTest]:
if self.tests is None and self.share_tests is None and context.configuration.include_tests:
return self._get_all_default_tests(context)
fingerprint = self.get_fingerprint()
return [
t.bind_by_label_count(fingerprint, label=label, slot="count")
for label, tests in (self.tests or {}).items()
for t in tests
] + [
t.bind_by_label_count(fingerprint, label=label, slot="share")
for label, tests in (self.share_tests or {}).items()
for t in tests
]


TByLabelCountMetric = TypeVar("TByLabelCountMetric", bound=ByLabelCountMetric)


class ByLabelCountCalculation(
MetricCalculation[ByLabelCountValue, TByLabelCountMetric], Generic[TByLabelCountMetric], ABC
):
def label_metric(self, label: Label) -> SingleValueCalculation:
raise NotImplementedError

def count_label_display_name(self, label: Label) -> str:
raise NotImplementedError

def share_label_display_name(self, label: Label) -> str:
raise NotImplementedError


class CountBoundTest(BoundTest[CountValue]):
is_count: bool

Expand Down
24 changes: 17 additions & 7 deletions src/evidently/future/metrics/column_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from evidently.future.datasets import Dataset
from evidently.future.datasets import DatasetColumn
from evidently.future.metric_types import BoundTest
from evidently.future.metric_types import ByLabelCalculation
from evidently.future.metric_types import ByLabelMetric
from evidently.future.metric_types import ByLabelValue
from evidently.future.metric_types import ByLabelCountCalculation
from evidently.future.metric_types import ByLabelCountMetric
from evidently.future.metric_types import ByLabelCountValue
from evidently.future.metric_types import CountCalculation
from evidently.future.metric_types import CountMetric
from evidently.future.metric_types import CountValue
Expand Down Expand Up @@ -643,11 +643,11 @@ def share_display_name(self) -> str:
return "Share of Drifted Columns"


class UniqueValueCount(ByLabelMetric):
class UniqueValueCount(ByLabelCountMetric):
column: str


class UniqueValueCountCalculation(ByLabelCalculation[UniqueValueCount]):
class UniqueValueCountCalculation(ByLabelCountCalculation[UniqueValueCount]):
def calculate(self, context: "Context", current_data: Dataset, reference_data: Optional[Dataset]):
current_result = self._calculate_value(current_data)
current_result.widget = distribution(
Expand All @@ -664,7 +664,17 @@ def calculate(self, context: "Context", current_data: Dataset, reference_data: O
def display_name(self) -> str:
return "Unique Value Count"

def count_label_display_name(self, label: Label) -> str:
return f"Unique Value Count for label {label}"

def share_label_display_name(self, label: Label) -> str:
return f"Unique Value Share for label {label}"

def _calculate_value(self, dataset: Dataset):
value_counts = dataset.as_dataframe()[self.metric.column].value_counts(dropna=False)
result = ByLabelValue(value_counts.to_dict()) # type: ignore[arg-type]
df = dataset.as_dataframe()
value_counts = df[self.metric.column].value_counts(dropna=False)
total = len(df)

res = value_counts.to_dict()
result = ByLabelCountValue(res, {k: v / total for k, v in res.items()}) # type: ignore[arg-type]
return result
5 changes: 3 additions & 2 deletions src/evidently/future/presets/dataset_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from evidently.core import ColumnType
from evidently.future.container import MetricContainer
from evidently.future.metric_types import ByLabelCountValue
from evidently.future.metric_types import ByLabelMetricTests
from evidently.future.metric_types import ByLabelValue
from evidently.future.metric_types import Metric
Expand Down Expand Up @@ -267,9 +268,9 @@ def _get_metric(
return [convert(context.get_metric_result(metric))]

def _most_common_value(self, unique_value: MetricResult):
if not isinstance(unique_value, ByLabelValue):
if not isinstance(unique_value, ByLabelCountValue):
raise ValueError("Most common value must be of type 'ByLabelValue'")
first = sorted(unique_value.values.items(), key=lambda x: x[1], reverse=True)[0]
first = sorted(unique_value.counts.items(), key=lambda x: x[1], reverse=True)[0]
return f"Label: {first[0]} count: {first[1]}"

def _label_count(
Expand Down
4 changes: 4 additions & 0 deletions src/evidently/metrics/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,3 +842,7 @@
)

register_type_alias(MetricResult, "evidently.future.backport.MeanStdValueV1", "evidently:metric_result:MeanStdValueV1")

register_type_alias(
MetricResult, "evidently.future.backport.ByLabelCountValueV1", "evidently:metric_result:ByLabelCountValueV1"
)
4 changes: 4 additions & 0 deletions src/evidently/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,14 @@ def _list_with_tags(self, current_tags: Set["IncludeTags"]) -> List[Tuple[List[A
if issubclass(self._cls, BaseResult) and self._cls.__config__.extract_as_obj:
return [(self._path, current_tags)]
res = []
from evidently.future.backport import ByLabelCountValueV1
from evidently.future.backport import ByLabelValueV1

if issubclass(self._cls, ByLabelValueV1):
res.append((self._path + ["values"], current_tags.union({IncludeTags.Render})))
if issubclass(self._cls, ByLabelCountValueV1):
res.append((self._path + ["counts"], current_tags.union({IncludeTags.Render})))
res.append((self._path + ["shares"], current_tags.union({IncludeTags.Render})))
for name, field in self._cls.__fields__.items():
field_value = field.type_

Expand Down
Loading
Loading