Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jan 12, 2025
1 parent fd97147 commit 82b41ba
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 60 deletions.
8 changes: 4 additions & 4 deletions tests/trace/test_uri_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
@pytest.fixture(
params=[
"dataset",
# "evaluation",
# "string_prompt",
# "messages_prompt",
# "easy_prompt",
"evaluation",
"string_prompt",
"messages_prompt",
"easy_prompt",
]
)
def obj(request):
Expand Down
2 changes: 0 additions & 2 deletions weave/flow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import weave
from weave.flow.obj import Object
from weave.trace import objectify
from weave.trace.vals import WeaveObject, WeaveTable


Expand All @@ -17,7 +16,6 @@ def short_str(obj: Any, limit: int = 25) -> str:
return str_val


@objectify.register
class Dataset(Object):
"""
Dataset object with easy saving and automatic versioning
Expand Down
2 changes: 0 additions & 2 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
get_scorer_attributes,
transpose,
)
from weave.trace import objectify
from weave.trace.env import get_weave_parallelism
from weave.trace.errors import OpCallError
from weave.trace.isinstance import weave_isinstance
Expand Down Expand Up @@ -59,7 +58,6 @@ class EvaluationResults(Object):
ScorerLike = Union[Callable, Op, Scorer]


@objectify.register
class Evaluation(Object):
"""
Sets up an evaluation which includes a set of scorers and a dataset.
Expand Down
11 changes: 8 additions & 3 deletions weave/flow/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ class Object(BaseModel):

@classmethod
def from_uri(cls, uri: str) -> Self:
"""It's up to the subclass to implement this!
if not hasattr(cls, "from_obj"):
raise NotImplementedError(
"Subclass must implement `from_obj` to support deserialization from a URI."
)

Using the @objectify.register decorator will also implement this automatically."""
raise NotImplementedError
import weave

obj = weave.ref(uri).get()
return cls.from_obj(obj)

# This is a "wrap" validator meaning we can run our own logic before
# and after the standard pydantic validation.
Expand Down
4 changes: 0 additions & 4 deletions weave/flow/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from weave.flow.obj import Object
from weave.flow.prompt.common import ROLE_COLORS, color_role
from weave.trace import objectify
from weave.trace.api import publish as weave_publish
from weave.trace.op import op
from weave.trace.refs import ObjectRef
Expand Down Expand Up @@ -79,7 +78,6 @@ def format(self, **kwargs: Any) -> Any:
raise NotImplementedError("Subclasses must implement format()")


@objectify.register
class StringPrompt(Prompt):
content: str = ""

Expand All @@ -98,7 +96,6 @@ def from_obj(cls, obj: WeaveObject) -> Self:
return prompt


@objectify.register
class MessagesPrompt(Prompt):
messages: list[dict] = Field(default_factory=list)

Expand Down Expand Up @@ -126,7 +123,6 @@ def from_obj(cls, obj: WeaveObject) -> Self:
return prompt


@objectify.register
class EasyPrompt(UserList, Prompt):
data: list = Field(default_factory=list)
config: dict = Field(default_factory=dict)
Expand Down
40 changes: 4 additions & 36 deletions weave/trace/objectify.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable

if TYPE_CHECKING:
from weave.trace.vals import WeaveObject

_registry: dict[str, type] = {}


T = TypeVar("T")


def register(cls: type[T]) -> type[T]:
"""Decorator to register a class with the objectify function.
Registered classes will be able to be deserialized directly into their base objects
instead of into a WeaveObject."""
_registry[cls.__name__] = cls

@runtime_checkable
class WeaveConvertible(Protocol):
@classmethod
def from_uri(cls: type[T], uri: str) -> T:
import weave

obj = weave.ref(uri).get()
if isinstance(obj, cls):
return obj
return cls(**obj.model_dump())

cls.from_uri = from_uri

return cls


def objectify(obj: WeaveObject) -> Any:
if not (cls_name := getattr(obj, "_class_name", None)):
return obj

if cls_name not in _registry:
return obj

cls = _registry[cls_name]
if hasattr(cls, "from_uri"):
return cls.from_uri(obj.ref.uri())

return obj
def from_obj(cls, obj: WeaveObject) -> Any: ...
6 changes: 3 additions & 3 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def uri(self) -> str:
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u

def get(self, *, objectify: bool = True) -> Any:
def get(self) -> Any:
# Move import here so that it only happens when the function is called.
# This import is invalid in the trace server and represents a dependency
# that should be removed.
Expand All @@ -171,7 +171,7 @@ def get(self, *, objectify: bool = True) -> Any:

gc = get_weave_client()
if gc is not None:
return gc.get(self, objectify=objectify)
return gc.get(self)

# Special case: If the user is attempting to fetch an object but has not
# yet initialized the client, we can initialize a client to
Expand All @@ -181,7 +181,7 @@ def get(self, *, objectify: bool = True) -> Any:
f"{self.entity}/{self.project}", ensure_project_exists=False
)
try:
res = init_client.client.get(self, objectify=objectify)
res = init_client.client.get(self)
finally:
init_client.reset()
return res
Expand Down
7 changes: 3 additions & 4 deletions weave/trace/vals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from pydantic import v1 as pydantic_v1

from weave.trace import box
from weave.trace import objectify as objectify_module
from weave.trace.context.tests_context import get_raise_on_captured_errors
from weave.trace.context.weave_client_context import get_weave_client
from weave.trace.errors import InternalError
from weave.trace.object_preparers import prepare_obj
from weave.trace.object_record import ObjectRecord
from weave.trace.objectify import WeaveConvertible
from weave.trace.op import is_op, maybe_bind_method
from weave.trace.refs import (
DICT_KEY_EDGE_NAME,
Expand Down Expand Up @@ -619,7 +619,6 @@ def make_trace_obj(
server: TraceServerInterface,
root: Optional[Traceable],
parent: Any = None,
objectify: bool = True,
) -> Any:
if isinstance(val, Traceable):
# If val is a WeaveTable, we want to refer to it via the outer object
Expand Down Expand Up @@ -710,8 +709,8 @@ def make_trace_obj(
if not isinstance(val, Traceable):
if isinstance(val, ObjectRecord):
obj = WeaveObject(val, ref=new_ref, server=server, root=root, parent=parent)
if objectify:
return objectify_module.objectify(obj)
if isinstance(obj, WeaveConvertible):
return obj.from_obj(obj)
return obj
elif isinstance(val, list):
return WeaveList(val, ref=new_ref, server=server, root=root, parent=parent)
Expand Down
4 changes: 2 additions & 2 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def save(self, val: Any, name: str, branch: str = "latest") -> Any:
return self.get(ref)

@trace_sentry.global_trace_sentry.watch()
def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any:
def get(self, ref: ObjectRef) -> Any:
project_id = f"{ref.entity}/{ref.project}"
try:
read_res = self.server.obj_read(
Expand Down Expand Up @@ -707,7 +707,7 @@ def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any:

val = from_json(data, project_id, self.server)

return make_trace_obj(val, ref, self.server, None, objectify=objectify)
return make_trace_obj(val, ref, self.server, None)

################ Query API ################

Expand Down

0 comments on commit 82b41ba

Please sign in to comment.