diff --git a/hemm/eval_pipelines/__init__.py b/hemm/eval_pipelines/__init__.py deleted file mode 100644 index dea3c26..0000000 --- a/hemm/eval_pipelines/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .eval_pipeline import EvaluationPipeline - -__all__ = ["EvaluationPipeline"] diff --git a/hemm/eval_pipelines/eval_pipeline.py b/hemm/eval_pipelines/eval_pipeline.py deleted file mode 100644 index d031b49..0000000 --- a/hemm/eval_pipelines/eval_pipeline.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -from abc import ABC -from typing import Dict, List, Union - -import weave - - -class EvaluationPipeline(ABC): - """Evaluation pipeline to evaluate the a multi-modal generative model. - - Args: - model (BaseDiffusionModel): The model to evaluate. - """ - - def __init__(self, model: weave.Model) -> None: - super().__init__() - self.model = model - - self.image_size = (self.model.image_height, self.model.image_width) - self.scorers = [] - - def add_metric(self, metric: Union[callable, weave.Scorer]): - """Add a metric function to the evaluation pipeline. - - Args: - metric (BaseMetric): Metric function to evaluate the generated images. - """ - self.scorers.append(metric) - - def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]: - """Evaluate the Stable Diffusion model on the given dataset. - - Args: - dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is - passed, it is assumed to be a Weave dataset reference. - """ - dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset - evaluation = weave.Evaluation(dataset=dataset, scorers=self.scorers) - summary = asyncio.run(evaluation.evaluate()) - return summary