Commit c15033c: more MyPy

tcapelle committed Dec 20, 2024
1 parent d5bcc4c

Showing 7 changed files with 48 additions and 44 deletions.
6 changes: 2 additions & 4 deletions weave/scorers/accuracy_scorer.py
@@ -24,7 +24,7 @@ class AccuracyScorer(Scorer):
     )

     @weave.op
-    def score(self, output: Any, ground_truth: Any) -> float:
+    def score(self, output: Any, ground_truth: Any) -> dict[str, Any]:
         """
         Compare a single prediction to the ground truth and return a binary correctness score.
@@ -95,7 +95,7 @@ def summarize(self, score_rows: list[dict]) -> Optional[dict]:

     def _summarize_multiclass(
         self, scores: list[float], outputs: list[Any], ground_truths: list[Any]
-    ) -> dict:
+    ) -> dict[str, Any]:
         """
         Summarize accuracy for multiclass tasks.
@@ -134,8 +134,6 @@ def _summarize_multiclass(
             accuracy = sum(
                 acc * weight for acc, weight in zip(per_class_accuracy, weights)
             )
-        elif self.average == "none":
-            accuracy = per_class_accuracy
         else:
            raise ValueError(f"Unsupported average type: {self.average}")
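Note: the dropped `average == "none"` branch assigned the per-class list to `accuracy`, which the surrounding code treats as a float, so it could not type-check; removing it rather than widening the type looks deliberate. A minimal sketch of the conflict mypy reports (names assumed, not copied from the repo):

    per_class_accuracy: list[float] = [0.9, 0.8, 0.7]
    accuracy: float = 0.0
    accuracy = per_class_accuracy  # error: Incompatible types in assignment
    # (expression has type "list[float]", variable has type "float")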
2 changes: 2 additions & 0 deletions weave/scorers/coherence_scorer.py
@@ -55,6 +55,7 @@ def set_pipeline(self) -> None:
     @weave.op
     def score_messages(self, prompt: str, output: str) -> dict[str, Any]:
         """Score a prompt response pair."""
+        assert self.pipeline is not None
         coherence_output = self.pipeline(inputs={"text": prompt, "text_pair": output})
         flagged = False
         if "incoherent" in coherence_output["label"].lower():
@@ -88,6 +89,7 @@ def _score_via_api(
     ) -> dict[str, Any]:
         import requests

+        assert self.base_url is not None
         response = requests.post(
             self.base_url,
             json={
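Note: the added `assert ... is not None` lines (here and in the files below) are the usual mypy narrowing idiom for Optional attributes: after the assert, mypy treats the attribute as non-None for the rest of the block. A minimal sketch of the pattern, with an illustrative class that is not from this repo:

    from typing import Any, Optional

    import requests


    class ApiScorer:
        base_url: Optional[str] = None

        def score(self, output: str) -> dict[str, Any]:
            # Without this assert, mypy flags passing Optional[str]
            # where a str URL is expected.
            assert self.base_url is not None
            response = requests.post(self.base_url, json={"output": output})
            response.raise_for_status()
            return response.json()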
24 changes: 13 additions & 11 deletions weave/scorers/context_relevance_scorer.py
@@ -74,7 +74,7 @@ class OldRelevanceScorer(Scorer):
         device: The device to use for inference. Defaults to `auto`, which will use `cuda` if available.
     """

-    model_name_or_path: str = None
+    model_name_or_path: str = ""
     base_url: Optional[str] = None
     device: str = "auto"
     _classifier: Any = PrivateAttr()
@@ -159,24 +159,24 @@ def _format_messages(
         chat_history = chat_history if isinstance(chat_history, list) else []
         context = context if isinstance(context, list) else []
         if context:
-            context = "\n".join(context).strip()
-            context = f"<documents>\n{context}\n</documents>"
+            joined_context = "\n".join(context).strip()
+            joined_context = f"<documents>\n{joined_context}\n</documents>"
         else:
-            context = ""
-        prompt = f"{context}\n\n{prompt}".strip()
+            joined_context = ""
+        prompt = f"{joined_context}\n\n{prompt}".strip()

         messages = chat_history + [{"role": "user", "content": prompt}]

         messages = [
-            f"<|msg_start|>{message['role']}\n{message['content']}<|msg_end|>"
+            f"<|msg_start|>{message['role']}\n{message['content']}<|msg_end|>"  # type: ignore
             for message in messages
         ]
-        messages = "\n".join(messages)
+        joined_messages = "\n".join(messages)  # type: ignore

-        context = f"<context>{messages}</context>\n"
+        final_context = f"<context>{joined_messages}</context>\n"
         completion = f"<completion>{completion}</completion>\n"

-        context_and_completion = context + completion
+        context_and_completion = final_context + completion

         return [
             {"role": "system", "content": self._system_prompt},
@@ -192,6 +192,7 @@ def _score_via_api(
     ) -> dict[str, Any]:
         import requests

+        assert self.base_url is not None
         response = requests.post(
             self.base_url,
             json={
@@ -386,7 +387,8 @@ def score(

         final_score = total_weighted_score / total_length if total_length > 0 else 0.0
         res = {"flagged": final_score > self.threshold}
-        res["extras"] = {"score": final_score}
+        extras = {"score": final_score}
         if verbose:
-            res["extras"]["all_spans"] = all_spans
+            extras["all_spans"] = all_spans  # type: ignore
+        res["extras"] = extras  # type: ignore
         return res
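Note: most of this file's hunk is one fix repeated: mypy binds a single type to each variable, so reusing `context` first as list[str] and then as str (and `messages` as a list and then a joined string) cannot be typed. Splitting into `joined_context`, `joined_messages`, and `final_context` gives each value its own name and type. A reduced sketch under assumed types:

    docs: list[str] = ["passage one", "passage two"]

    # Reusing the same name with a new type fails:
    # docs = "\n".join(docs)  # error: Incompatible types in assignment
    #                         # (expression has type "str", variable has type "list[str]")

    # A fresh name for the joined value type-checks cleanly:
    joined_docs: str = "\n".join(docs)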
24 changes: 14 additions & 10 deletions weave/scorers/hallucination_scorer.py
@@ -86,7 +86,9 @@
 """


-def get_chat_template_messages(query: str, output: str, context: Optional[str] = None):
+def get_chat_template_messages(
+    query: str, output: str, context: Optional[str] = None
+) -> list[dict[str, str]]:
     system_prompt = """The task is to evaluate whether the <output> contains \
 information not supported by the <query> or <context>, or \
 whether the <output> contradicts the information provided in the <query> or <context>.
@@ -250,14 +252,14 @@ class HallucinationScorer(HuggingFaceScorer):
     max_new_tokens: int = 2
     model_max_length: int = 8192
     do_sample: bool = False
-    temperature: float = 0.0
+    temperature: Optional[float] = 0.0
     num_beams: int = 1
-    top_k: int = 20
-    top_p: float = 0.7
+    top_k: Optional[int] = 20
+    top_p: Optional[float] = 0.7
     use_torch_compile: bool = False
     use_hhem: bool = True
     hhem_score_threshold: float = 0.5
-    _local_model_path: str = None
+    _local_model_path: str = ""
     import_failed: bool = False

     def load_model(self) -> None:
@@ -339,6 +341,7 @@ def load_tokenizer(self) -> None:
     def _score_via_api(self, messages: list[dict[str, str]]) -> dict[str, Any]:
         import requests

+        assert self.base_url is not None
         response = requests.post(self.base_url, json={"messages": messages})
         response.raise_for_status()
         return response.json()
@@ -423,13 +426,14 @@ def score(self, query: str, context: str, output: str) -> dict[str, Any]:
         true_token = 2787
         false_token = 4245

-        input_length = inp_tokenized["input_ids"].shape[1]
-        completion_tokens = res[0][input_length:].tolist()
+        input_length = inp_tokenized["input_ids"].shape[1]  # type: ignore
+        completion_tokens = res[0][input_length:].tolist()  # type: ignore

         is_hallucination = true_token in completion_tokens
+        extras: dict[str, Any] = {"score": 1 if is_hallucination else 0}
         result = {
             "flagged": is_hallucination,
-            "extras": {"score": 1 if is_hallucination else 0},
+            "extras": extras,
         }

         if self.debug:
@@ -443,11 +447,11 @@ def score(self, query: str, context: str, output: str) -> dict[str, Any]:
             completion = self._tokenizer.decode(completion_tokens)
             print(f"COMPLETION:\n{completion}\n----------------------\n")

-            result["extras"].update(
+            extras.update(
                 {
                     "completion": completion,
                     "completion_tokens": completion_tokens,
-                    "total_tokens": len(res[0]),
+                    "total_tokens": len(res[0]),  # type: ignore
                     "total_completion_tokens": len(completion_tokens),
                     "scorer_worked": scorer_worked,
                 }
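Note: the `extras` rework is another recurring mypy pattern. Built inline, `{"score": 1 if is_hallucination else 0}` is inferred as dict[str, int], so the later `.update(...)` with string and list values would be rejected; declaring `extras: dict[str, Any]` up front keeps those writes legal. A short sketch with simplified values:

    from typing import Any

    extras = {"score": 1}                     # inferred as dict[str, int]
    # extras["completion"] = "text"           # error: value has type "str", expected "int"

    extras_ok: dict[str, Any] = {"score": 1}  # widened explicitly
    extras_ok["completion"] = "text"          # OK
    extras_ok.update({"total_tokens": 42})    # OK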
8 changes: 3 additions & 5 deletions weave/scorers/llm_scorer.py
@@ -131,9 +131,6 @@ def load_pipeline(self) -> None:
             "Subclasses must implement the `load_pipeline` method."
         )

-    def pipe(self, prompt: str) -> list[dict[str, Any]]:
-        return self.pipeline(prompt)[0]
-
     @weave.op
     def score(self, *, output: Any, **kwargs: Any) -> Any:
         raise NotImplementedError
@@ -144,8 +141,8 @@ class HuggingFaceScorer(Scorer):

     model_name_or_path: str = Field(default="", description="The path to the model")
     device: str = Field(default="auto", description="The device to use for the model")
-    model: Optional[Any] = None
-    tokenizer: Optional[Any] = None
+    model: Any = None
+    tokenizer: Any = None

     def model_post_init(self, __context: Any = None) -> None:
         """Template method for post-initialization."""
@@ -198,6 +195,7 @@ def tokenize_input(self, prompt: str) -> "Tensor":
         Returns:
             A tensor of tokenized input IDs.
         """
+        assert self.tokenizer is not None
         return self.tokenizer(
             prompt, return_tensors="pt", truncation=False
         ).input_ids.to(self.device)
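Note: swapping `Optional[Any]` for plain `Any` on `model` and `tokenizer` is not a no-op to mypy: `Any` already admits `None`, and under `Optional[Any]` mypy still complains when the possible `None` is called or dereferenced, while `Any` defers those checks to runtime. A tiny sketch of the difference (illustrative classes, not from this file):

    from typing import Any, Optional

    class WithOptional:
        tokenizer: Optional[Any] = None

        def run(self) -> None:
            self.tokenizer("hi")  # error: "None" not callable

    class WithAny:
        tokenizer: Any = None

        def run(self) -> None:
            self.tokenizer("hi")  # accepted: Any is unchecked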
16 changes: 7 additions & 9 deletions weave/scorers/llm_utils.py
@@ -2,12 +2,7 @@

 import inspect
 import os
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-try:
-    import instructor
-except ImportError:
-    instructor = None
+from typing import TYPE_CHECKING, Any, Union

 OPENAI_DEFAULT_MODEL = "gpt-4o"
 OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
@@ -24,6 +19,7 @@
 LOCAL_MODEL_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "weave_models")

 if TYPE_CHECKING:
+    import instructor.client
     from anthropic import Anthropic, AsyncAnthropic
     from google.generativeai import GenerativeModel
     from instructor.patch import InstructorChatCompletionCreate
@@ -57,7 +53,9 @@


 def instructor_client(client: _LLM_CLIENTS) -> instructor.client:
-    if instructor is None:
+    try:
+        import instructor
+    except ImportError:
         raise ImportError(
             "The `instructor` package is required to use LLM-powered scorers, please run `pip install instructor`"
         )
@@ -113,7 +111,7 @@ def embed(
     raise ValueError(f"Unsupported client type: {type(client).__name__.lower()}")


-def set_device(device: Optional[str] = None) -> device:
+def set_device(device: str = "auto") -> device:
     """Set the device to use for the model.

     Args:
@@ -128,7 +126,7 @@ def set_device(device: Optional[str] = None) -> device:
     if not cuda_available and "cuda" in device:
         # could be `cuda:0`, `cuda:1`, etc.
         raise ValueError("CUDA is not available")
-    if device == "auto" or device is None:
+    if device == "auto":
         if cuda_available:
             device = "cuda"
         elif torch.backends.mps.is_available():
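Note: the `instructor` handling moves from a module-level try/except with an `instructor = None` fallback to the standard optional-dependency pattern: import inside the function for the runtime check, plus a TYPE_CHECKING import so annotations still resolve when the package is absent. A generic sketch of the pattern (`somepackage` is a placeholder, not a real dependency):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers, never executed at runtime.
        import somepackage


    def make_client() -> "somepackage.Client":
        try:
            import somepackage
        except ImportError:
            raise ImportError(
                "The `somepackage` package is required, please run `pip install somepackage`"
            )
        return somepackage.Client()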
12 changes: 7 additions & 5 deletions weave/scorers/moderation_scorer.py
@@ -134,7 +134,7 @@ class ToxicityScorer(RollingWindowScorer):
         ]
     )

-    def load_model(self):
+    def load_model(self) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
@@ -155,7 +155,7 @@ def load_model(self):
         )
         self.model.eval()

-    def load_tokenizer(self):
+    def load_tokenizer(self) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
@@ -168,7 +168,7 @@
         self.tokenizer = AutoTokenizer.from_pretrained(self._local_model_path)
         print(f"Model and tokenizer loaded on {self.device}")

-    def predict_chunk(self, input_ids: "Tensor") -> list[int]:
+    def predict_chunk(self, input_ids: "Tensor") -> list[int | float]:
         """
         Predict toxicity scores for a chunk of tokenized input.
@@ -191,6 +191,7 @@ def predict_chunk(self, input_ids: "Tensor") -> list[int]:
     def _score_via_api(self, output: str) -> dict[str, Any]:
         import requests

+        assert self.base_url is not None
         response = requests.post(self.base_url, json={"output": output})
         response.raise_for_status()
         return response.json()
@@ -253,7 +254,7 @@ class BiasScorer(RollingWindowScorer):
         ]
     )

-    def load_model(self):
+    def load_model(self) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
@@ -275,7 +276,7 @@ def load_model(self):
         )
         self.model.eval()

-    def load_tokenizer(self):
+    def load_tokenizer(self) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
@@ -300,6 +301,7 @@ def predict_chunk(self, input_ids: "Tensor") -> list[float]:
     def _score_via_api(self, output: str, verbose: bool = False) -> dict[str, Any]:
         import requests

+        assert self.base_url is not None
         response = requests.post(
             self.base_url,
             json={"output": output, "verbose": verbose},
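Note: the `-> None` annotations on `load_model` and `load_tokenizer` do more than document the return value: unless `--check-untyped-defs` is set, mypy skips the body of a function with no annotations at all, so adding the return type opts those bodies into checking. A quick illustration:

    def load_unchecked():
        x: int = "oops"  # not reported: untyped function bodies are skipped by default


    def load_checked() -> None:
        x: int = "oops"  # error: Incompatible types in assignment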
