Commit 4e474ea (1 parent: ac62655)
update: docs
soumik12345 committed Oct 30, 2024

Showing 2 changed files with 10 additions and 11 deletions.

docs/metrics/spatial_relationship.md: 9 additions, 11 deletions
````diff
@@ -20,21 +20,18 @@
 ## Step 2: Evaluate
 
 ```python
-import wandb
+import asyncio
 import weave
 
-from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
+from hemm.models import DiffusersModel
 from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
 from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
 
-# Initialize Weave and WandB
-wandb.init(project="image-quality-leaderboard", job_type="evaluation")
+# Initialize Weave
 weave.init(project_name="image-quality-leaderboard")
 
-# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel`
-model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")
-
-# Add the model to the evaluation pipeline
-evaluation_pipeline = EvaluationPipeline(model=model)
+# Initialize the diffusion model to be evaluated as a `weave.Model`
+model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")
 
 # Define the judge model for 2d spatial relationship metric
 judge = DETRSpatialRelationShipJudge(
@@ -43,10 +40,11 @@
 
 # Add 2d spatial relationship Metric to the evaluation pipeline
 metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")
-evaluation_pipeline.add_metric(metric)
 
 # Evaluate!
-evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0")
+dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get()
+evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
+summary = asyncio.run(evaluation.evaluate(model))
 ```
 
 ## Metrics
````
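For reference, a minimal sketch of the updated example assembled into a single script is shown below. The import path and constructor arguments for `DETRSpatialRelationShipJudge` are collapsed in the diff above, so the `hemm.metrics.spatial_relationship.judges` path and the no-argument constructor used here are assumptions; consult the package for the exact signature.

```python
# Sketch of the post-commit example assembled into one script.
# NOTE: the import path and constructor arguments of DETRSpatialRelationShipJudge
# are not visible in the diff; both are assumptions here.
import asyncio

import weave

from hemm.models import DiffusersModel
from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge  # assumed path

# Initialize Weave
weave.init(project_name="image-quality-leaderboard")

# Initialize the diffusion model to be evaluated as a `weave.Model`
model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")

# Judge model for the 2D spatial relationship metric (arguments elided in the diff; defaults assumed)
judge = DETRSpatialRelationShipJudge()

# 2D spatial relationship metric used as a Weave scorer
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")

# Fetch the prompt dataset and run the evaluation against the model
dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get()
evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
summary = asyncio.run(evaluation.evaluate(model))
print(summary)
```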
pyproject.toml: 1 addition, 0 deletions

````diff
@@ -31,6 +31,7 @@ dev = [
     "isort>=5.13.2",
     "black>=24.10.0",
     "ruff>=0.6.9",
+    "pytest>=8.3.3",
 ]
 
 docs = [
````
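The new `pytest` dev dependency suggests the test suite is run with pytest. A minimal, hypothetical test file illustrating the kind of smoke check this enables (the actual tests in the repository may differ):

```python
# tests/test_imports.py (hypothetical; actual tests in the repo may differ)
def test_package_imports():
    # Smoke test: the hemm package and the metric discussed in this commit import cleanly.
    import hemm
    from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D

    assert SpatialRelationshipMetric2D is not None
```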
