Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jan 10, 2025
1 parent 72a5d1d commit fd97147
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
3 changes: 1 addition & 2 deletions weave/flow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ class Dataset(Object):

@classmethod
def from_obj(cls, obj: WeaveObject) -> Self:
rows = obj.rows
return cls(rows=rows)
return cls(rows=obj.rows)

@field_validator("rows", mode="before")
def convert_to_table(cls, rows: Any) -> weave.Table:
Expand Down
1 change: 1 addition & 0 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def from_obj(cls, obj: WeaveObject) -> Self:
dataset=obj.dataset,
scorers=obj.scorers,
preprocess_model_input=obj.preprocess_model_input,
trials=obj.trials,
evaluation_name=obj.evaluation_name,
)

Expand Down
12 changes: 7 additions & 5 deletions weave/flow/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pydantic import Field
from rich.table import Table
from typing_extensions import Self

from weave.flow.obj import Object
from weave.flow.prompt.common import ROLE_COLORS, color_role
Expand All @@ -17,6 +18,7 @@
from weave.trace.op import op
from weave.trace.refs import ObjectRef
from weave.trace.rich import pydantic_util
from weave.trace.vals import WeaveObject


class Message(TypedDict):
Expand Down Expand Up @@ -89,7 +91,7 @@ def format(self, **kwargs: Any) -> str:
return self.content.format(**kwargs)

@classmethod
def from_obj(cls, obj: Any) -> "StringPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(content=obj.content)
prompt.name = obj.name
prompt.description = obj.description
Expand Down Expand Up @@ -117,7 +119,7 @@ def format(self, **kwargs: Any) -> list:
return [self.format_message(m, **kwargs) for m in self.messages]

@classmethod
def from_obj(cls, obj: Any) -> "MessagesPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(messages=obj.messages)
prompt.name = obj.name
prompt.description = obj.description
Expand Down Expand Up @@ -424,7 +426,7 @@ def as_dict(self) -> dict[str, Any]:
}

@classmethod
def from_obj(cls, obj: Any) -> "EasyPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
messages = obj.messages if hasattr(obj, "messages") else obj.data
messages = [dict(m) for m in messages]
config = dict(obj.config)
Expand All @@ -438,7 +440,7 @@ def from_obj(cls, obj: Any) -> "EasyPrompt":
)

@staticmethod
def load(fp: IO) -> "EasyPrompt":
def load(fp: IO) -> Self:
if isinstance(fp, str): # Common mistake
raise TypeError(
"Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?"
Expand All @@ -448,7 +450,7 @@ def load(fp: IO) -> "EasyPrompt":
return prompt

@staticmethod
def load_file(filepath: Union[str, Path]) -> "Prompt":
def load_file(filepath: Union[str, Path]) -> Self:
expanded_path = os.path.expanduser(str(filepath))
with open(expanded_path) as f:
return EasyPrompt.load(f)
Expand Down

0 comments on commit fd97147

Please sign in to comment.