
Commit

Union -> |
tcapelle committed Dec 20, 2024
1 parent c15033c · commit 3a024da
Showing 6 changed files with 26 additions and 25 deletions.
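
The pattern throughout this commit: typing.Union[X, Y] annotations become PEP 604
"X | Y" unions, and each touched module gains "from __future__ import annotations"
so that annotations are stored as strings (PEP 563) instead of being evaluated at
import time. A minimal sketch of the resulting style, using an illustrative
function that is not taken from the repository:

    # Sketch only; embed_texts is a made-up example, not weave code.
    from __future__ import annotations  # must be the first statement in a module

    from typing import Optional


    def embed_texts(
        texts: str | list[str],  # previously Union[str, list[str]]
        model: Optional[str] = None,
    ) -> list[list[float]]:
        """Toy function showing pipe-style unions in annotations."""
        items = [texts] if isinstance(texts, str) else texts
        return [[float(len(item))] for item in items]

Because the future import defers annotation evaluation, the pipe syntax in
signatures stays importable on interpreters older than Python 3.10, where
"str | list[str]" would otherwise raise a TypeError at function-definition time.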
6 changes: 4 additions & 2 deletions weave/scorers/context_relevance_scorer.py
@@ -1,7 +1,9 @@
 import json
 import os
 from importlib.util import find_spec
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
+from __future__ import annotations
+
 import numpy as np
 from pydantic import PrivateAttr
@@ -362,7 +364,7 @@ def score(
         self,
         output: str,
         query: str,
-        context: Union[str, list[str]],
+        context: str | list[str],
         verbose: bool = False,
     ) -> dict[str, Any]:
         """Score multiple documents and compute weighted average relevance."""
8 changes: 5 additions & 3 deletions weave/scorers/llm_scorer.py
@@ -1,5 +1,7 @@
 from importlib.util import find_spec
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional
 
+from __future__ import annotations
+
 from pydantic import Field, field_validator

@@ -200,11 +202,11 @@ def tokenize_input(self, prompt: str) -> "Tensor":
             prompt, return_tensors="pt", truncation=False
         ).input_ids.to(self.device)
 
-    def predict_chunk(self, input_ids: "Tensor") -> list[Union[int, float]]:
+    def predict_chunk(self, input_ids: "Tensor") -> list[int | float]:
         raise NotImplementedError("Subclasses must implement predict_chunk method.")
 
     def aggregate_predictions(
-        self, all_predictions: list[list[Union[int, float]]]
+        self, all_predictions: list[list[int | float]]
     ) -> list[float]:
         """
         Aggregate predictions using the specified class attribute method.
16 changes: 4 additions & 12 deletions weave/scorers/llm_utils.py
@@ -2,7 +2,8 @@

 import inspect
 import os
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any
+from __future__ import annotations
 
 OPENAI_DEFAULT_MODEL = "gpt-4o"
 OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
@@ -27,16 +28,7 @@
     from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
     from torch import device
 
-    _LLM_CLIENTS = Union[
-        OpenAI,
-        AsyncOpenAI,
-        AzureOpenAI,
-        AsyncAzureOpenAI,
-        Anthropic,
-        AsyncAnthropic,
-        Mistral,
-        GenerativeModel,
-    ]
+    _LLM_CLIENTS = OpenAI | AsyncOpenAI | AzureOpenAI | AsyncAzureOpenAI | Anthropic | AsyncAnthropic | Mistral | GenerativeModel
 else:
     _LLM_CLIENTS = object

@@ -94,7 +86,7 @@ def create(


 def embed(
-    client: _LLM_CLIENTS, model_id: str, texts: Union[str, list[str]], **kwargs: Any
+    client: _LLM_CLIENTS, model_id: str, texts: str | list[str], **kwargs: Any
 ) -> list[list[float]]:
     client_type = type(client).__name__.lower()
     if "openai" in client_type:
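
Judging by the else branch that falls back to object, the _LLM_CLIENTS alias above
lives under an "if TYPE_CHECKING:" guard. That block is only read by static type
checkers and never executed, so the interpreter never has to build the pipe union
at runtime, and the one-liner is safe even on Pythons that predate types.UnionType.
A small sketch of that pattern, with hypothetical client classes standing in for
the real SDK types:

    from __future__ import annotations

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # Seen only by type checkers; never evaluated at runtime.
        from some_sdk import AsyncClient, SyncClient  # hypothetical imports

        _CLIENTS = SyncClient | AsyncClient
    else:
        # Runtime placeholder so the name exists when the module is imported.
        _CLIENTS = object


    def call_model(client: _CLIENTS, prompt: str, **kwargs: Any) -> str:
        # call_model and client.complete are invented for this sketch.
        return client.complete(prompt, **kwargs)

The annotation on call_model is likewise never evaluated at runtime, because the
future import stores it as a string that only tools such as mypy resolve.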
2 changes: 2 additions & 0 deletions weave/scorers/moderation_scorer.py
@@ -1,6 +1,8 @@
 import os
 from typing import TYPE_CHECKING, Any, Optional
 
+from __future__ import annotations
+
 from pydantic import PrivateAttr, field_validator
 
 import weave
11 changes: 6 additions & 5 deletions weave/scorers/robustness_scorer.py
@@ -1,7 +1,8 @@
 import math
 import random
 import string
-from typing import Any, Optional, Union
+from typing import Any, Optional
+from __future__ import annotations
 
 import numpy as np

@@ -79,15 +80,15 @@ def model_post_init(self, __context: Any) -> None:
     @weave.op
     def score(
         self,
-        output: list[Union[str, bool]],
-        ground_truths: Optional[list[Union[str, bool]]] = None,
+        output: list[str | bool],
+        ground_truths: Optional[list[str | bool]] = None,
     ) -> dict:
         """
         Computes the robustness score of the model's outputs.
 
         Args:
-            output (List[Union[str, bool]]): A list containing the original output followed by perturbed outputs.
-            ground_truths (Optional[List[Union[str, bool]]]): Optional list of ground truths corresponding to each output.
+            output (List[str | bool]): A list containing the original output followed by perturbed outputs.
+            ground_truths (Optional[List[str | bool]]): Optional list of ground truths corresponding to each output.
 
         Returns:
             dict: A dictionary containing the robustness metrics and scores.
8 changes: 5 additions & 3 deletions weave/trace_server/requests.py
@@ -5,7 +5,9 @@
 import os
 import threading
 from time import time
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
+from __future__ import annotations
+
 from requests import HTTPError as HTTPError
 from requests import PreparedRequest, Response, Session
@@ -34,7 +36,7 @@
 THEME_JSON = "ansi_dark"
 
 
-def decode_str(string: Union[str, bytes]) -> str:
+def decode_str(string: str | bytes) -> str:
     """Decode a bytes object to a string."""
     return string if isinstance(string, str) else string.decode("utf-8")

@@ -161,7 +163,7 @@ def get(url: str, params: Optional[dict[str, str]] = None, **kwargs: Any) -> Res

 def post(
     url: str,
-    data: Optional[Union[dict[str, Any], str]] = None,
+    data: Optional[dict[str, Any] | str] = None,
     json: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> Response:
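
One small note on the post() hunk above: Optional[T] and "T | None" spell the same
type, so the data parameter could equally be written as a single pipe chain. A
sketch of the two equivalent spellings (stub bodies, not the trace_server
implementation):

    from __future__ import annotations

    from typing import Any, Optional


    # Both annotations describe the same type: a dict payload, a raw string
    # body, or no body at all.
    def post_with_optional(data: Optional[dict[str, Any] | str] = None) -> None:
        ...


    def post_with_pipes(data: dict[str, Any] | str | None = None) -> None:
        ...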
