Skip to content

Commit

Permalink
update: docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Oct 30, 2024
1 parent 4f29ad0 commit 29c44bb
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions hemm/metrics/spatial_relationship/spatial_relationship_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,32 @@ class SpatialRelationshipMetric2D(weave.Scorer):
"""Spatial relationship metric for image generation as proposed in Section 4.2 from the paper
[T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).
??? example "Sample usage"
!!! example "Sample usage"
```python
import wandb
import asyncio
import weave
from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
from hemm.models import DiffusersModel
from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge
# Initialize Weave and WandB
wandb.init(project="image-quality-leaderboard", job_type="evaluation")
# Initialize Weave
weave.init(project_name="image-quality-leaderboard")
# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseDiffusionModel`
model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")
# Add the model to the evaluation pipeline
evaluation_pipeline = EvaluationPipeline(model=model)
# Initialize the diffusion model to be evaluated as a `weave.Model`
model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")
# Define the judge model for 2d spatial relationship metric
judge = DETRSpatialRelationShipJudge(
model_address=detr_model_address, revision=detr_revision
)
# Add 2d spatial relationship Metric to the evaluation pipeline
# Define the 2D spatial relationship metric
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")
evaluation_pipeline.add_metric(metric)
# Evaluate!
evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0")
dataset = weave.ref("2d-spatial-t2i_compbench_spatial_prompts-mscoco:v0").get()
evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
asyncio.run(evaluation.evaluate(model))
```
Args:
Expand Down

0 comments on commit 29c44bb

Please sign in to comment.