Add unit tests for BiasScorer and ToxicityScorer
- Introduced unit tests for `BiasScorer` and `ToxicityScorer` to validate their scoring functionalities.
- Implemented tests for handling basic strings, large inputs, and threshold logic to ensure robustness.
- Added fixtures to initialize scorer instances with tiny downloaded models for testing.
- Enhanced test coverage for both scorers, ensuring they correctly flag biased and toxic content based on defined thresholds.
Showing 3 changed files with 447 additions and 16 deletions.
@@ -0,0 +1,64 @@
import pytest

from tests.scorers.test_utils import TINY_MODEL_PATHS, generate_large_text
from weave.scorers.llm_utils import download_model
from weave.scorers.moderation_scorer import BiasScorer


@pytest.fixture
def bias_scorer():
    """Fixture that returns a BiasScorer instance using a tiny downloaded model."""
    tiny_model_path = download_model(TINY_MODEL_PATHS["bias_scorer"])
    return BiasScorer(
        model_name_or_path=tiny_model_path,
        name="test-bias",
        description="Tiny test bias scorer",
        device="cpu",  # Use CPU for testing
        threshold=0.5,
    )


def test_bias_scorer_simple(bias_scorer):
    """Tests that the scorer can handle a basic string."""
    output = "This is a balanced statement with no bias."
    result = bias_scorer.score(output)

    assert isinstance(result, dict), "Result should be a dictionary."
    assert "flagged" in result, "Result must contain the 'flagged' key."
    assert "extras" in result, "Result must contain the 'extras' dictionary."
    assert isinstance(result["flagged"], bool), "'flagged' must be a boolean."
    assert isinstance(result["extras"], dict), "'extras' must be a dictionary."
    for category in bias_scorer._categories:
        assert (
            category in result["extras"]
        ), f"'{category}' should be in 'extras' dictionary."


def test_bias_scorer_large_input(bias_scorer):
    """Tests bias scorer with very large input to ensure it doesn't crash."""
    large_text = generate_large_text(100_000)  # 100k characters
    result = bias_scorer.score(large_text)

    assert isinstance(result, dict)
    assert "flagged" in result
    assert isinstance(result["flagged"], bool)
    assert "extras" in result
    for category in bias_scorer._categories:
        assert category in result["extras"], f"'{category}' should be in 'extras'."


def test_bias_scorer_threshold(bias_scorer):
    """
    Confirms that the threshold logic works properly.
    This example text is contrived to produce mild results in categories.
    """
    text = "This text slightly contains gender bias and maybe a hint of racial bias."
    result = bias_scorer.score(text)

    assert isinstance(result, dict), "Result should be a dictionary."
    assert "flagged" in result, "Should have a 'flagged' key in result."
    assert isinstance(result["flagged"], bool), "'flagged' must be a boolean."
    scores = result["extras"]
    # If any score is above threshold, it should be flagged
    should_flag = any(v is True for v in scores.values())
    assert result["flagged"] == should_flag
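
Read together, these assertions pin down the shape of a BiasScorer.score() result: a "flagged" boolean plus an "extras" dict containing one boolean per entry in bias_scorer._categories. The snippet below is only an illustration of that shape and of the aggregation rule the threshold test checks; the category names are placeholders, not the scorer's actual labels.

# Illustrative only: the category names here are assumed placeholders; the tests
# merely require that every entry of bias_scorer._categories appears in "extras".
example_result = {
    "flagged": True,
    "extras": {
        "gender_bias": True,   # per-category boolean, decided at threshold=0.5
        "racial_bias": False,
    },
}

# Aggregation rule asserted by test_bias_scorer_threshold:
assert example_result["flagged"] == any(
    v is True for v in example_result["extras"].values()
)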
@@ -0,0 +1,83 @@
import pytest

from tests.scorers.test_utils import TINY_MODEL_PATHS, generate_large_text
from weave.scorers.llm_utils import download_model
from weave.scorers.moderation_scorer import ToxicityScorer


@pytest.fixture
def toxicity_scorer():
    """
    Fixture that returns a ToxicityScorer instance using a tiny downloaded model,
    similar to test_context_relevance_scorer.py logic.
    """
    tiny_model_path = download_model(TINY_MODEL_PATHS["toxicity_scorer"])
    return ToxicityScorer(
        model_name_or_path=tiny_model_path,
        name="test-toxicity",
        description="Tiny test toxicity scorer",
        device="cpu",  # Use CPU for testing
        total_threshold=5,
        category_threshold=2,
    )


def test_toxicity_scorer_simple(toxicity_scorer):
    """Tests that the scorer can handle a basic string."""
    output = "Just a simple, harmless sentence."
    result = toxicity_scorer.score(output)

    assert isinstance(result, dict), "Result should be a dictionary."
    assert "flagged" in result, "Result must contain the 'flagged' key."
    assert "extras" in result, "Result must contain the 'extras' dictionary."
    assert isinstance(result["flagged"], bool), "'flagged' must be a boolean."
    assert isinstance(result["extras"], dict), "'extras' must be a dictionary."
    for category in toxicity_scorer._categories:
        assert (
            category in result["extras"]
        ), f"'{category}' should be in 'extras' dictionary."


def test_toxicity_scorer_large_input(toxicity_scorer):
    """Tests toxicity scorer with very large input to ensure it doesn't crash."""
    large_text = generate_large_text(100_000)  # 100k characters
    result = toxicity_scorer.score(large_text)

    assert isinstance(result, dict)
    assert "flagged" in result
    assert isinstance(result["flagged"], bool)
    assert "extras" in result
    assert all(cat in result["extras"] for cat in toxicity_scorer._categories)


def test_toxicity_scorer_threshold(toxicity_scorer):
    """
    Confirms that the total threshold logic and category threshold logic both work.
    This example text is contrived to produce mild results in multiple categories.
    """
    text = "This text slightly offends many groups just a little bit."
    result = toxicity_scorer.score(text)

    assert isinstance(result, dict)
    assert "flagged" in result
    assert isinstance(result["flagged"], bool)
    assert "extras" in result
    total_score = sum(result["extras"].values())
    highest_cat_score = max(result["extras"].values())
    should_flag = (total_score >= toxicity_scorer.total_threshold) or (
        highest_cat_score >= toxicity_scorer.category_threshold
    )
    assert result["flagged"] == should_flag


# def test_toxicity_scorer_clean(toxicity_scorer):
#     """
#     Tests behavior on benign content that should not trigger any category flags.
#     """
#     clean_text = "I really enjoy friendly and respectful conversations."
#     result = toxicity_scorer.score(clean_text)

#     assert isinstance(result, dict)
#     assert "extras" in result
#     assert all(cat in result["extras"] for cat in toxicity_scorer._categories)
#     assert not result["flagged"], "Clean text should not be flagged."
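
Both test modules also depend on two helpers from tests.scorers.test_utils, whose diff did not render in this view. The sketch below is only a guess at what minimal versions of those helpers could look like; the TINY_MODEL_PATHS values and the text-generation strategy are assumptions, not the contents of the actual file.

# Hypothetical sketch of tests/scorers/test_utils.py; the real helpers may differ.
TINY_MODEL_PATHS = {
    # Placeholder locations of small test checkpoints resolved by download_model.
    "bias_scorer": "<location-of-tiny-bias-model>",
    "toxicity_scorer": "<location-of-tiny-toxicity-model>",
}


def generate_large_text(num_chars: int = 100_000) -> str:
    """Return a string of roughly num_chars characters for stress tests."""
    sentence = "This is neutral filler text used to stress-test the scorers. "
    repeats = num_chars // len(sentence) + 1
    return (sentence * repeats)[:num_chars]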