From 82b41ba39a50fab239af59c434167e6987e905f9 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sun, 12 Jan 2025 18:26:20 -0500 Subject: [PATCH] test --- tests/trace/test_uri_get.py | 8 ++++---- weave/flow/dataset.py | 2 -- weave/flow/eval.py | 2 -- weave/flow/obj.py | 11 +++++++--- weave/flow/prompt/prompt.py | 4 ---- weave/trace/objectify.py | 40 ++++--------------------------------- weave/trace/refs.py | 6 +++--- weave/trace/vals.py | 7 +++---- weave/trace/weave_client.py | 4 ++-- 9 files changed, 24 insertions(+), 60 deletions(-) diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_uri_get.py index 7554eba5a4b8..732e82e3d093 100644 --- a/tests/trace/test_uri_get.py +++ b/tests/trace/test_uri_get.py @@ -6,10 +6,10 @@ @pytest.fixture( params=[ "dataset", - # "evaluation", - # "string_prompt", - # "messages_prompt", - # "easy_prompt", + "evaluation", + "string_prompt", + "messages_prompt", + "easy_prompt", ] ) def obj(request): diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 59a21fb573db..fc05dcd519d8 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -6,7 +6,6 @@ import weave from weave.flow.obj import Object -from weave.trace import objectify from weave.trace.vals import WeaveObject, WeaveTable @@ -17,7 +16,6 @@ def short_str(obj: Any, limit: int = 25) -> str: return str_val -@objectify.register class Dataset(Object): """ Dataset object with easy saving and automatic versioning diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 731458d676df..193741b33a47 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -28,7 +28,6 @@ get_scorer_attributes, transpose, ) -from weave.trace import objectify from weave.trace.env import get_weave_parallelism from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance @@ -59,7 +58,6 @@ class EvaluationResults(Object): ScorerLike = Union[Callable, Op, Scorer] -@objectify.register class Evaluation(Object): """ Sets up an evaluation which includes a set of scorers and a dataset. diff --git a/weave/flow/obj.py b/weave/flow/obj.py index 7247a2acbf74..f3691b56409f 100644 --- a/weave/flow/obj.py +++ b/weave/flow/obj.py @@ -54,10 +54,15 @@ class Object(BaseModel): @classmethod def from_uri(cls, uri: str) -> Self: - """It's up to the subclass to implement this! + if not hasattr(cls, "from_obj"): + raise NotImplementedError( + "Subclass must implement `from_obj` to support deserialization from a URI." + ) - Using the @objectify.register decorator will also implement this automatically.""" - raise NotImplementedError + import weave + + obj = weave.ref(uri).get() + return cls.from_obj(obj) # This is a "wrap" validator meaning we can run our own logic before # and after the standard pydantic validation. diff --git a/weave/flow/prompt/prompt.py b/weave/flow/prompt/prompt.py index c12a5b95dc3e..6068012e43ac 100644 --- a/weave/flow/prompt/prompt.py +++ b/weave/flow/prompt/prompt.py @@ -13,7 +13,6 @@ from weave.flow.obj import Object from weave.flow.prompt.common import ROLE_COLORS, color_role -from weave.trace import objectify from weave.trace.api import publish as weave_publish from weave.trace.op import op from weave.trace.refs import ObjectRef @@ -79,7 +78,6 @@ def format(self, **kwargs: Any) -> Any: raise NotImplementedError("Subclasses must implement format()") -@objectify.register class StringPrompt(Prompt): content: str = "" @@ -98,7 +96,6 @@ def from_obj(cls, obj: WeaveObject) -> Self: return prompt -@objectify.register class MessagesPrompt(Prompt): messages: list[dict] = Field(default_factory=list) @@ -126,7 +123,6 @@ def from_obj(cls, obj: WeaveObject) -> Self: return prompt -@objectify.register class EasyPrompt(UserList, Prompt): data: list = Field(default_factory=list) config: dict = Field(default_factory=dict) diff --git a/weave/trace/objectify.py b/weave/trace/objectify.py index 5c53e3fd19eb..9669a4509e44 100644 --- a/weave/trace/objectify.py +++ b/weave/trace/objectify.py @@ -1,46 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from weave.trace.vals import WeaveObject -_registry: dict[str, type] = {} - - T = TypeVar("T") -def register(cls: type[T]) -> type[T]: - """Decorator to register a class with the objectify function. - - Registered classes will be able to be deserialized directly into their base objects - instead of into a WeaveObject.""" - _registry[cls.__name__] = cls - +@runtime_checkable +class WeaveConvertible(Protocol): @classmethod - def from_uri(cls: type[T], uri: str) -> T: - import weave - - obj = weave.ref(uri).get() - if isinstance(obj, cls): - return obj - return cls(**obj.model_dump()) - - cls.from_uri = from_uri - - return cls - - -def objectify(obj: WeaveObject) -> Any: - if not (cls_name := getattr(obj, "_class_name", None)): - return obj - - if cls_name not in _registry: - return obj - - cls = _registry[cls_name] - if hasattr(cls, "from_uri"): - return cls.from_uri(obj.ref.uri()) - - return obj + def from_obj(cls, obj: WeaveObject) -> Any: ... diff --git a/weave/trace/refs.py b/weave/trace/refs.py index 3102f0e7ef42..91a4f159e7fb 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -162,7 +162,7 @@ def uri(self) -> str: u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u - def get(self, *, objectify: bool = True) -> Any: + def get(self) -> Any: # Move import here so that it only happens when the function is called. # This import is invalid in the trace server and represents a dependency # that should be removed. @@ -171,7 +171,7 @@ def get(self, *, objectify: bool = True) -> Any: gc = get_weave_client() if gc is not None: - return gc.get(self, objectify=objectify) + return gc.get(self) # Special case: If the user is attempting to fetch an object but has not # yet initialized the client, we can initialize a client to @@ -181,7 +181,7 @@ def get(self, *, objectify: bool = True) -> Any: f"{self.entity}/{self.project}", ensure_project_exists=False ) try: - res = init_client.client.get(self, objectify=objectify) + res = init_client.client.get(self) finally: init_client.reset() return res diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 09240a4d3642..c3d06feec96a 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -11,12 +11,12 @@ from pydantic import v1 as pydantic_v1 from weave.trace import box -from weave.trace import objectify as objectify_module from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.context.weave_client_context import get_weave_client from weave.trace.errors import InternalError from weave.trace.object_preparers import prepare_obj from weave.trace.object_record import ObjectRecord +from weave.trace.objectify import WeaveConvertible from weave.trace.op import is_op, maybe_bind_method from weave.trace.refs import ( DICT_KEY_EDGE_NAME, @@ -619,7 +619,6 @@ def make_trace_obj( server: TraceServerInterface, root: Optional[Traceable], parent: Any = None, - objectify: bool = True, ) -> Any: if isinstance(val, Traceable): # If val is a WeaveTable, we want to refer to it via the outer object @@ -710,8 +709,8 @@ def make_trace_obj( if not isinstance(val, Traceable): if isinstance(val, ObjectRecord): obj = WeaveObject(val, ref=new_ref, server=server, root=root, parent=parent) - if objectify: - return objectify_module.objectify(obj) + if isinstance(obj, WeaveConvertible): + return obj.from_obj(obj) return obj elif isinstance(val, list): return WeaveList(val, ref=new_ref, server=server, root=root, parent=parent) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index dfe9a8ec9809..0754b9f57eb5 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -663,7 +663,7 @@ def save(self, val: Any, name: str, branch: str = "latest") -> Any: return self.get(ref) @trace_sentry.global_trace_sentry.watch() - def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any: + def get(self, ref: ObjectRef) -> Any: project_id = f"{ref.entity}/{ref.project}" try: read_res = self.server.obj_read( @@ -707,7 +707,7 @@ def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any: val = from_json(data, project_id, self.server) - return make_trace_obj(val, ref, self.server, None, objectify=objectify) + return make_trace_obj(val, ref, self.server, None) ################ Query API ################