From 3a024da162cc5d3fd0b25dcc8d8f731aa1dcf21b Mon Sep 17 00:00:00 2001
From: Thomas Capelle
Date: Fri, 20 Dec 2024 14:31:47 +0100
Subject: [PATCH] Union -> |

---
 weave/scorers/context_relevance_scorer.py |  6 ++++--
 weave/scorers/llm_scorer.py               |  8 +++++---
 weave/scorers/llm_utils.py                | 16 ++++------------
 weave/scorers/moderation_scorer.py        |  2 ++
 weave/scorers/robustness_scorer.py        | 11 ++++++-----
 weave/trace_server/requests.py            |  8 +++++---
 6 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/weave/scorers/context_relevance_scorer.py b/weave/scorers/context_relevance_scorer.py
index 9239045c176..65ff1d22856 100644
--- a/weave/scorers/context_relevance_scorer.py
+++ b/weave/scorers/context_relevance_scorer.py
@@ -1,7 +1,9 @@
 import json
 import os
 from importlib.util import find_spec
-from typing import Any, Optional, Union
+from typing import Any, Optional
+
+from __future__ import annotations
 
 import numpy as np
 from pydantic import PrivateAttr
@@ -362,7 +364,7 @@ def score(
         self,
         output: str,
         query: str,
-        context: Union[str, list[str]],
+        context: str | list[str],
         verbose: bool = False,
     ) -> dict[str, Any]:
         """Score multiple documents and compute weighted average relevance."""
diff --git a/weave/scorers/llm_scorer.py b/weave/scorers/llm_scorer.py
index c597ac7926f..58bc787ab0a 100644
--- a/weave/scorers/llm_scorer.py
+++ b/weave/scorers/llm_scorer.py
@@ -1,5 +1,7 @@
 from importlib.util import find_spec
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional
+
+from __future__ import annotations
 
 from pydantic import Field, field_validator
 
@@ -200,11 +202,11 @@ def tokenize_input(self, prompt: str) -> "Tensor":
             prompt, return_tensors="pt", truncation=False
         ).input_ids.to(self.device)
 
-    def predict_chunk(self, input_ids: "Tensor") -> list[Union[int, float]]:
+    def predict_chunk(self, input_ids: "Tensor") -> list[int | float]:
         raise NotImplementedError("Subclasses must implement predict_chunk method.")
 
     def aggregate_predictions(
-        self, all_predictions: list[list[Union[int, float]]]
+        self, all_predictions: list[list[int | float]]
     ) -> list[float]:
         """
         Aggregate predictions using the specified class attribute method.
diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py
index ad1d910b6ad..e3d17fd642d 100644
--- a/weave/scorers/llm_utils.py
+++ b/weave/scorers/llm_utils.py
@@ -2,7 +2,8 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any
+from __future__ import annotations
 
 OPENAI_DEFAULT_MODEL = "gpt-4o"
 OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
 
@@ -27,16 +28,7 @@
     from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
     from torch import device
 
-    _LLM_CLIENTS = Union[
-        OpenAI,
-        AsyncOpenAI,
-        AzureOpenAI,
-        AsyncAzureOpenAI,
-        Anthropic,
-        AsyncAnthropic,
-        Mistral,
-        GenerativeModel,
-    ]
+    _LLM_CLIENTS = OpenAI | AsyncOpenAI | AzureOpenAI | AsyncAzureOpenAI | Anthropic | AsyncAnthropic | Mistral | GenerativeModel
 else:
     _LLM_CLIENTS = object
 
@@ -94,7 +86,7 @@ def create(
 
 
 def embed(
-    client: _LLM_CLIENTS, model_id: str, texts: Union[str, list[str]], **kwargs: Any
+    client: _LLM_CLIENTS, model_id: str, texts: str | list[str], **kwargs: Any
 ) -> list[list[float]]:
     client_type = type(client).__name__.lower()
     if "openai" in client_type:
diff --git a/weave/scorers/moderation_scorer.py b/weave/scorers/moderation_scorer.py
index 2387a374c04..3ab5d7a4ffe 100644
--- a/weave/scorers/moderation_scorer.py
+++ b/weave/scorers/moderation_scorer.py
@@ -1,6 +1,8 @@
 import os
 from typing import TYPE_CHECKING, Any, Optional
 
+from __future__ import annotations
+
 from pydantic import PrivateAttr, field_validator
 
 import weave
diff --git a/weave/scorers/robustness_scorer.py b/weave/scorers/robustness_scorer.py
index ce4ac533a36..f3273e90b1a 100644
--- a/weave/scorers/robustness_scorer.py
+++ b/weave/scorers/robustness_scorer.py
@@ -1,7 +1,8 @@
 import math
 import random
 import string
-from typing import Any, Optional, Union
+from typing import Any, Optional
+from __future__ import annotations
 
 import numpy as np
 
@@ -79,15 +80,15 @@ def model_post_init(self, __context: Any) -> None:
     @weave.op
     def score(
         self,
-        output: list[Union[str, bool]],
-        ground_truths: Optional[list[Union[str, bool]]] = None,
+        output: list[str | bool],
+        ground_truths: Optional[list[str | bool]] = None,
     ) -> dict:
         """
         Computes the robustness score of the model's outputs.
 
         Args:
-            output (List[Union[str, bool]]): A list containing the original output followed by perturbed outputs.
-            ground_truths (Optional[List[Union[str, bool]]]): Optional list of ground truths corresponding to each output.
+            output (List[str | bool]): A list containing the original output followed by perturbed outputs.
+            ground_truths (Optional[List[str | bool]]): Optional list of ground truths corresponding to each output.
 
         Returns:
             dict: A dictionary containing the robustness metrics and scores.
diff --git a/weave/trace_server/requests.py b/weave/trace_server/requests.py
index f9bf2bbd7b9..921f872328b 100644
--- a/weave/trace_server/requests.py
+++ b/weave/trace_server/requests.py
@@ -5,7 +5,9 @@
 import os
 import threading
 from time import time
-from typing import Any, Optional, Union
+from typing import Any, Optional
+
+from __future__ import annotations
 
 from requests import HTTPError as HTTPError
 from requests import PreparedRequest, Response, Session
@@ -34,7 +36,7 @@
 THEME_JSON = "ansi_dark"
 
 
-def decode_str(string: Union[str, bytes]) -> str:
+def decode_str(string: str | bytes) -> str:
     """Decode a bytes object to a string."""
     return string if isinstance(string, str) else string.decode("utf-8")
 
@@ -161,7 +163,7 @@ def get(url: str, params: Optional[dict[str, str]] = None, **kwargs: Any) -> Res
 
 def post(
     url: str,
-    data: Optional[Union[dict[str, Any], str]] = None,
+    data: Optional[dict[str, Any] | str] = None,
     json: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> Response:
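
Note on the syntax this patch adopts (illustrative, not part of the diff): the Union[X, Y] spellings are replaced by PEP 604 "X | Y". That form evaluates natively only on Python 3.10+; on older versions it only parses inside annotations when "from __future__ import annotations" is in effect, and that import must be the very first statement of a module (after any docstring). The sketch below shows the two spellings side by side; embed_old and embed_new are hypothetical stand-ins loosely modeled on the patched embed() signature, not code from the weave repository.

    # Illustrative sketch only -- embed_old / embed_new are hypothetical
    # stand-ins, not functions from the weave codebase.
    from __future__ import annotations  # must precede every other statement

    from typing import Optional, Union


    def embed_old(texts: Union[str, list[str]]) -> list[list[float]]:
        """Pre-patch spelling using typing.Union."""
        items = [texts] if isinstance(texts, str) else texts
        return [[0.0] for _ in items]


    def embed_new(texts: str | list[str]) -> list[list[float]]:
        """Post-patch PEP 604 spelling; identical meaning to the Union form."""
        return embed_old(texts)


    # Optional[...] is left untouched by this patch; it is shorthand for "X | None":
    params: Optional[dict[str, str]] = None  # same as dict[str, str] | None

With the future import, annotations are stored as strings rather than evaluated at definition time, so the "|" spelling parses on Python 3.9; code that resolves these annotations at runtime (for example via typing.get_type_hints) still needs Python 3.10+ for the "|" form to evaluate.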