Commit: mypy
tcapelle committed Dec 20, 2024
1 parent a9d036a commit 29bf188
Showing 4 changed files with 35 additions and 23 deletions.
weave/scorers/context_relevance_scorer.py (4 changes: 2 additions & 2 deletions)
@@ -265,7 +265,7 @@ class ContextRelevanceScorer(HuggingFaceScorer):
     threshold: float = 0.7
     model_max_length: int = 1280
 
-    def load_model(self):
+    def load_model(self) -> None:
         try:
             if find_spec("torch") is None:
                 raise ImportError("torch is required but not installed")
@@ -285,7 +285,7 @@ def load_model(self):
             )
         self.model.eval()
 
-    def load_tokenizer(self):
+    def load_tokenizer(self) -> None:
         try:
             from transformers import AutoTokenizer
         except ImportError:
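A note on the annotation change above: with `disallow_untyped_defs` enabled, mypy rejects any def that lacks a return annotation, which is why the commit adds `-> None` to loaders that return nothing. A minimal sketch, assuming a strict mypy invocation (the file and class names here are illustrative, not from this repo):

```python
# scorer_stub.py -- hypothetical file, checked with:
#   mypy --disallow-untyped-defs scorer_stub.py

class UntypedScorer:
    def load_model(self):  # error: Function is missing a return type annotation
        self.model = "stub"


class TypedScorer:
    def load_model(self) -> None:  # OK: explicitly annotated as returning nothing
        self.model = "stub"
```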
weave/scorers/hallucination_scorer.py (42 changes: 25 additions & 17 deletions)
@@ -10,7 +10,6 @@
     OPENAI_DEFAULT_MODEL,
     create,
     download_model,
-    set_device,
 )
 from weave.scorers.utils import stringify
 
@@ -87,7 +86,7 @@
 """
 
 
-def get_chat_template_messages(query: str, output: str, context: str = None):
+def get_chat_template_messages(query: str, output: str, context: Optional[str] = None):
     system_prompt = """The task is to evaluate whether the <output> contains \
 information not supported by the <query> or <context>, or \
 whether the <output> contradicts the information provided in the <query> or <context>.
@@ -261,7 +260,7 @@ class HallucinationScorer(HuggingFaceScorer):
     _local_model_path: str = None
     import_failed: bool = False
 
-    def model_post_init(self, __context) -> None:
+    def load_model(self) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
@@ -272,17 +271,14 @@ def model_post_init(self, __context) -> None:
             from transformers import (
                 AutoModelForCausalLM,
                 AutoModelForSequenceClassification,
-                AutoTokenizer,
             )
         except ImportError:
             self.import_failed = True
             print(
-                "The `transformers` package is required to use the HallucinationScorer, please run `pip install transformers`"
+                f"The `transformers` package is required to use the {self.__class__.__name__}, please run `pip install transformers`"
             )
             return
 
-        self.device = set_device(self.device)
-
         if self.model is None:
             # Check if the model is already downloaded
             if os.path.isdir(self.model_name_or_path):
@@ -303,28 +299,42 @@
                     self._local_model_path,
                     torch_dtype="bfloat16",
                     trust_remote_code=True,
-                ).to(self.device)
+                    device_map=self.device,
+                )
+                self._tokenizer = self.model.tokenizer
+                self._tokenizer.model_max_length = self.model_max_length
             else:
                 self.model = AutoModelForCausalLM.from_pretrained(
-                    self._local_model_path, torch_dtype="bfloat16"
-                ).to(self.device)
+                    self._local_model_path,
+                    torch_dtype="bfloat16",
+                    device_map=self.device,
+                )
 
         if self.use_torch_compile:
             self.model.generation_config.cache_implementation = "static"
             self.model = torch.compile(
                 self.model, backend="inductor", fullgraph=True
             )
-        if not self.do_sample:
-            self.top_k = None
-            self.top_p = None
-            self.temperature = None
         self.model.eval()
 
+    def load_tokenizer(self) -> None:
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            self.import_failed = True
+            print(
+                f"The `transformers` package is required to use the {self.__class__.__name__}, please run `pip install transformers`"
+            )
+            return
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self._local_model_path, model_max_length=self.model_max_length
+            )
+        if not self.do_sample:
+            self.top_k = None
+            self.top_p = None
+            self.temperature = None
+        print(f"Tokenizer loaded on {self.device}")
 
     def _score_via_api(self, messages: list[dict[str, str]]) -> dict[str, Any]:
         import requests
@@ -397,9 +407,7 @@ def score(self, query: str, context: str, output: str) -> dict[str, Any]:
 
         pad_token_id = self._tokenizer.eos_token_id
 
-        with torch.no_grad():
-            self.model.eval()
-
+        with torch.inference_mode():
             res = self.model.generate(
                 inp_tokenized["input_ids"],
                 max_new_tokens=self.max_new_tokens,
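Two patterns in this file's changes are worth unpacking: passing `device_map` to `from_pretrained` places the weights at load time (via the `accelerate` package) instead of materializing them and then calling `.to(device)`, and `torch.inference_mode()` is a stricter variant of `torch.no_grad()` that also disables view tracking and version counting. A rough sketch of the combined pattern, with `"gpt2"` standing in for the downloaded hallucination checkpoint and `"cuda"` for `self.device`:

```python
# Sketch only: "gpt2" and "cuda" are stand-ins, and device_map requires
# the `accelerate` package to be installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype="bfloat16",
    device_map="cuda",  # weights placed at load time, replacing .to(device)
)
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_max_length=1280)
model.eval()

inputs = tokenizer("The sky is green.", return_tensors="pt").to(model.device)
with torch.inference_mode():  # like no_grad, but no autograd bookkeeping at all
    out = model.generate(
        inputs["input_ids"],
        max_new_tokens=16,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

One caveat: tensors created under `inference_mode` cannot later take part in autograd, which is fine for a scorer that only ever runs inference.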
weave/scorers/llm_scorer.py (8 changes: 6 additions & 2 deletions)
@@ -1,3 +1,4 @@
+from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from pydantic import Field, field_validator
@@ -114,7 +115,10 @@ class HuggingFacePipelineScorer(Scorer):
     def model_post_init(self, __context: Any) -> None:
         self.device = set_device(self.device)
         try:
-            from transformers import pipeline
+            if find_spec("transformers") is None:
+                print(
+                    "The `transformers` package is required to use PipelineScorer, please run `pip install transformers`"
+                )
         except ImportError:
             print(
                 "The `transformers` package is required to use PipelineScorer, please run `pip install transformers`"
@@ -198,7 +202,7 @@ def tokenize_input(self, prompt: str) -> "Tensor":
             prompt, return_tensors="pt", truncation=False
         ).input_ids.to(self.device)
 
-    def predict_chunk(self, input_ids: "Tensor") -> list[int]:
+    def predict_chunk(self, input_ids: "Tensor") -> list[Union[int, float]]:
         raise NotImplementedError("Subclasses must implement predict_chunk method.")
 
     def aggregate_predictions(
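The `find_spec` change above swaps a real import for an availability probe: `importlib.util.find_spec` reports whether a package can be imported without actually importing it, and it returns `None` for a missing top-level package rather than raising (so the surrounding `except ImportError` is belt-and-braces at most). A small sketch of the pattern, with a hypothetical `require` helper:

```python
# Hypothetical helper, not from this repo: probe for an optional
# dependency without paying the cost of importing it.
from importlib.util import find_spec


def require(package: str) -> bool:
    if find_spec(package) is None:
        print(f"The `{package}` package is required, please run `pip install {package}`")
        return False
    return True


if require("transformers"):
    from transformers import pipeline  # safe: the spec was found
```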
weave/scorers/moderation_scorer.py (4 changes: 2 additions & 2 deletions)
@@ -139,7 +139,7 @@ def load_model(self):
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
         try:
-            from transformers import AutoModelForSequenceClassification, AutoTokenizer
+            from transformers import AutoModelForSequenceClassification
         except ImportError:
             print(
                 f"The `transformers` package is required to use {self.__class__.__name__}, please run `pip install transformers`"
@@ -160,7 +160,7 @@ def load_tokenizer(self):
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
         try:
-            from transformers import AutoModelForSequenceClassification, AutoTokenizer
+            from transformers import AutoTokenizer
         except ImportError:
             print(
                 f"The `transformers` package is required to use {self.__class__.__name__}, please run `pip install transformers`"
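The import narrowing above follows the convention used throughout these scorers: each loader imports only the symbol it actually uses, inside the method that runs it, so a missing optional dependency fails (or warns) exactly where it matters. A compressed sketch of that shape, with illustrative class and checkpoint names:

```python
# Illustrative only: the class name and checkpoint are placeholders.
class ToxicityScorer:
    def load_model(self) -> None:
        # import only what this loader needs, and only when it runs
        from transformers import AutoModelForSequenceClassification

        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased"
        )

    def load_tokenizer(self) -> None:
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
```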
