diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index d5f6eb5569d8..59a21fb573db 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -47,8 +47,7 @@ class Dataset(Object): @classmethod def from_obj(cls, obj: WeaveObject) -> Self: - rows = obj.rows - return cls(rows=rows) + return cls(rows=obj.rows) @field_validator("rows", mode="before") def convert_to_table(cls, rows: Any) -> weave.Table: diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 0dd50926b843..731458d676df 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -123,6 +123,7 @@ def from_obj(cls, obj: WeaveObject) -> Self: dataset=obj.dataset, scorers=obj.scorers, preprocess_model_input=obj.preprocess_model_input, + trials=obj.trials, evaluation_name=obj.evaluation_name, ) diff --git a/weave/flow/prompt/prompt.py b/weave/flow/prompt/prompt.py index d6289412fa3f..c12a5b95dc3e 100644 --- a/weave/flow/prompt/prompt.py +++ b/weave/flow/prompt/prompt.py @@ -9,6 +9,7 @@ from pydantic import Field from rich.table import Table +from typing_extensions import Self from weave.flow.obj import Object from weave.flow.prompt.common import ROLE_COLORS, color_role @@ -17,6 +18,7 @@ from weave.trace.op import op from weave.trace.refs import ObjectRef from weave.trace.rich import pydantic_util +from weave.trace.vals import WeaveObject class Message(TypedDict): @@ -89,7 +91,7 @@ def format(self, **kwargs: Any) -> str: return self.content.format(**kwargs) @classmethod - def from_obj(cls, obj: Any) -> "StringPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: prompt = cls(content=obj.content) prompt.name = obj.name prompt.description = obj.description @@ -117,7 +119,7 @@ def format(self, **kwargs: Any) -> list: return [self.format_message(m, **kwargs) for m in self.messages] @classmethod - def from_obj(cls, obj: Any) -> "MessagesPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: prompt = cls(messages=obj.messages) prompt.name = obj.name prompt.description = obj.description @@ -424,7 +426,7 @@ def as_dict(self) -> dict[str, Any]: } @classmethod - def from_obj(cls, obj: Any) -> "EasyPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: messages = obj.messages if hasattr(obj, "messages") else obj.data messages = [dict(m) for m in messages] config = dict(obj.config) @@ -438,7 +440,7 @@ def from_obj(cls, obj: Any) -> "EasyPrompt": ) @staticmethod - def load(fp: IO) -> "EasyPrompt": + def load(fp: IO) -> Self: if isinstance(fp, str): # Common mistake raise TypeError( "Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?" @@ -448,7 +450,7 @@ def load(fp: IO) -> "EasyPrompt": return prompt @staticmethod - def load_file(filepath: Union[str, Path]) -> "Prompt": + def load_file(filepath: Union[str, Path]) -> Self: expanded_path = os.path.expanduser(str(filepath)) with open(expanded_path) as f: return EasyPrompt.load(f)