From e41ce26b982eaaaa57a625d6c293f10fae158de6 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 08:46:43 -0800 Subject: [PATCH 001/120] Initial Migration --- .../007_add_refs_to_feedback.down.sql | 5 ++ .../007_add_refs_to_feedback.up.sql | 48 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 weave/trace_server/migrations/007_add_refs_to_feedback.down.sql create mode 100644 weave/trace_server/migrations/007_add_refs_to_feedback.up.sql diff --git a/weave/trace_server/migrations/007_add_refs_to_feedback.down.sql b/weave/trace_server/migrations/007_add_refs_to_feedback.down.sql new file mode 100644 index 00000000000..7504258720d --- /dev/null +++ b/weave/trace_server/migrations/007_add_refs_to_feedback.down.sql @@ -0,0 +1,5 @@ +ALTER TABLE feedback + DROP COLUMN annotation_ref, + DROP COLUMN runnable_ref, + DROP COLUMN call_ref, + DROP COLUMN trigger_ref; diff --git a/weave/trace_server/migrations/007_add_refs_to_feedback.up.sql b/weave/trace_server/migrations/007_add_refs_to_feedback.up.sql new file mode 100644 index 00000000000..f5c7c296abd --- /dev/null +++ b/weave/trace_server/migrations/007_add_refs_to_feedback.up.sql @@ -0,0 +1,48 @@ +/* +This migration adds the following columns to the feedback table: +- `annotation_ref`: The ref pointing to the annotation definition for this feedback. +- `runnable_ref`: The ref pointing to the runnable definition for this feedback. +- `call_ref`: The ref pointing to the resulting call associated with generating this feedback. +- `trigger_ref`: The ref pointing to the trigger definition which resulted in this feedback. + +We are enhancing the feedback table to support richer payloads - specifically those generated +from scoring functions and/or human annotations. These additional columns allow us +to join/query/filter on the referenced entities (aka foreign keys) without loading +the entire payload into memory. 
+ +Note, there are two classes of feedback types: +- `wandb.annotation.*`: Feedback generated by a human annotator. + - Here, `*` is a placeholder for the name of the annotation field (which is expected to be the name part of the annotation ref) +- `wandb.runnable.*`: Feedback generated by a machine scoring function. + - Here, `*` is a placeholder for the name of the runnable (which is expected to be the name part of the runnable ref) + +Furthermore, the fields are mostly mutually exclusive, where: +- `wandb.annotation.*` feedback will have `annotation_ref` populated. +- `wandb.runnable.*` feedback will have `runnable_ref` populated and optionally (`call_ref` and `trigger_ref`). +However, it is conceivable that in the future a user might want to use a runnable to generate feedback that +corresponds to an annotation field! + +*/ +ALTER TABLE feedback + /* + `annotation_ref`: The ref pointing to the annotation definition for this feedback. + Expected to be present on any feedback type starting with `wandb.annotation`. + */ + ADD COLUMN annotation_ref Nullable(String) DEFAULT NULL, + /* + `runnable_ref`: The ref pointing to the runnable definition for this feedback. + This can be an op, a configured action, etc... + Expected to be present on any feedback type starting with `wandb.runnable`. + */ + ADD COLUMN runnable_ref Nullable(String) DEFAULT NULL, + /* + `call_ref`: The ref pointing to the resulting call associated with generating this feedback. + Expected (but not required) to be present on any feedback that has `runnable_ref` as a + call-producing op. + */ + ADD COLUMN call_ref Nullable(String) DEFAULT NULL, + /* + `trigger_ref`: The ref pointing to the trigger definition which resulted in this feedback. + Will be present when the runnable_ref has been executed by a trigger, not a human/small batch job. 
+ */ + ADD COLUMN trigger_ref Nullable(String) DEFAULT NULL; From 3b6cef932d0e5f6cee5e6b6dd76d7cadd92abca5 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 10:16:33 -0800 Subject: [PATCH 002/120] Interface and basic validation --- weave/trace_server/feedback.py | 88 ++++++++++++++++++++ weave/trace_server/trace_server_interface.py | 12 +++ 2 files changed, 100 insertions(+) diff --git a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 3283987f48e..55813115d43 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -1,5 +1,8 @@ +from typing import Optional, Tuple, Type, TypeVar, Union, overload + from pydantic import BaseModel, ValidationError +from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest from weave.trace_server.orm import Column, Table @@ -28,6 +31,49 @@ "wandb.note.1": tsi.FeedbackPayloadNoteReq, } +ANNOTATION_FEEDBACK_TYPE_PREFIX = "wandb.annotation" +RUNNABLE_FEEDBACK_TYPE_PREFIX = "wandb.runnable" + +T = TypeVar( + "T", ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef, ri.InternalOpRef +) + + +@overload +def _ensure_ref_is_valid( + ref: str, expected_type: None = None +) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: ... + + +@overload +def _ensure_ref_is_valid( + ref: str, + expected_type: Tuple[Type[T], ...], +) -> T: ... + + +def _ensure_ref_is_valid( + ref: str, expected_type: Optional[Tuple[Type, ...]] = None +) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: + """Validates and parses an internal URI reference. 
+ + Args: + ref: The reference string to validate + expected_type: Optional tuple of expected reference types + + Returns: + The parsed internal reference object + + Raises: + InvalidRequest: If the reference is invalid or doesn't match expected_type + """ + parsed_ref = ri.parse_internal_uri(ref) + if expected_type and not isinstance(parsed_ref, expected_type): + raise InvalidRequest( + f"Invalid ref: {ref}, expected {(t.__name__ for t in expected_type)}" + ) + return parsed_ref + def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: payload_schema = FEEDBACK_PAYLOAD_SCHEMAS.get(req.feedback_type) @@ -39,6 +85,48 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: f"Invalid payload for feedback_type {req.feedback_type}: {e}" ) + # Validate the required fields for the feedback type. + if req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX): + if not req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX + "."): + raise InvalidRequest( + f"Invalid annotation feedback type: {req.feedback_type}" + ) + type_subname = req.feedback_type[len(ANNOTATION_FEEDBACK_TYPE_PREFIX) + 1 :] + if not req.annotation_ref: + raise InvalidRequest("annotation_ref is required for annotation feedback") + annotation_ref = _ensure_ref_is_valid( + req.annotation_ref, (ri.InternalObjectRef,) + ) + if annotation_ref.name != type_subname: + raise InvalidRequest( + f"annotation_ref must point to an object with name {type_subname}" + ) + elif req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX): + if not req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX + "."): + raise InvalidRequest(f"Invalid runnable feedback type: {req.feedback_type}") + type_subname = req.feedback_type[len(RUNNABLE_FEEDBACK_TYPE_PREFIX) + 1 :] + if not req.runnable_ref: + raise InvalidRequest("runnable_ref is required for runnable feedback") + runnable_ref = _ensure_ref_is_valid( + req.runnable_ref, (ri.InternalOpRef, ri.InternalObjectRef) + ) + if 
runnable_ref.name != type_subname: + raise InvalidRequest( + f"runnable_ref must point to an object with name {type_subname}" + ) + if isinstance(runnable_ref, ri.InternalOpRef) and not req.call_ref: + raise InvalidRequest("call_ref is required for runnable feedback on ops") + + # Validate the ref formats (we could even query the DB to ensure they exist and are valid) + if req.annotation_ref: + _ensure_ref_is_valid(req.annotation_ref, (ri.InternalObjectRef,)) + if req.runnable_ref: + _ensure_ref_is_valid(req.runnable_ref, (ri.InternalOpRef, ri.InternalObjectRef)) + if req.call_ref: + _ensure_ref_is_valid(req.call_ref, (ri.InternalCallRef,)) + if req.trigger_ref: + _ensure_ref_is_valid(req.trigger_ref, (ri.InternalObjectRef,)) + MESSAGE_INVALID_FEEDBACK_PURGE = ( "Can only purge feedback by specifying one or more feedback ids" diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index af8778dc1d3..88b849d644c 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -686,6 +686,18 @@ class FeedbackCreateReq(BaseModel): } ] ) + annotation_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/object/name:digest"] + ) + runnable_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/op/name:digest"] + ) + call_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/call/call_id"] + ) + trigger_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/object/name:digest"] + ) # wb_user_id is automatically populated by the server wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) From 3ae1ca21bc3599057b25a4a39fcbb83273661124 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 12:40:00 -0800 Subject: [PATCH 003/120] Added tests and Assertions --- tests/trace/test_feedback.py | 251 +++++++++++++++++++++++++++++++++ 
weave/trace_server/feedback.py | 48 ++++++- 2 files changed, 295 insertions(+), 4 deletions(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index cc05dccd624..06ed7033ab6 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -1,5 +1,9 @@ import pytest +from weave.trace.weave_client import WeaveClient +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.errors import InvalidRequest + def test_client_feedback(client) -> None: feedbacks = client.get_feedback() @@ -62,3 +66,250 @@ def test_custom_feedback(client) -> None: with pytest.raises(ValueError): trace_object.feedback.add("wandb.trying_to_use_reserved_prefix", value=1) + + +def test_annotation_feedback(client: WeaveClient) -> None: + project_id = client._project_id() + column_name = "column_name" + feedback_type = f"wandb.annotation.{column_name}" + weave_ref = f"weave:///{project_id}/call/cal_id_123" + annotation_ref = f"weave:///{project_id}/object/{column_name}:obj_id_123" + payload = {"value": 1} + + # Case 1: Errors with no name in type (dangle or char len 0) + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.annotation", # No name + payload=payload, + annotation_ref=annotation_ref, + ) + ) + + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.annotation.", # Trailing period + payload=payload, + annotation_ref=annotation_ref, + ) + ) + # Case 2: Errors with incorrect ref string format + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload=payload, + annotation_ref=f"weave:///{project_id}/object/{column_name}", # No digest + ) + ) + # Case 3: Errors with name mismatch 
+ with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type + "_wrong_name", + payload=payload, + annotation_ref=annotation_ref, + ) + ) + # Case 4: Errors if annotation ref is present but incorrect type + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="not.annotation", + payload=payload, + annotation_ref=f"weave:///{project_id}/op/{column_name}:obj_id_123", + ) + ) + + # Case 5: Invalid payload + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload={"not": "a valid payload"}, + annotation_ref=annotation_ref, + ) + ) + + # Success + create_res = client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload=payload, + annotation_ref=annotation_ref, + ) + ) + assert create_res.id != None + # Correct Query Result Payload + query_res = client.server.feedback_query( + tsi.FeedbackQueryReq( + project_id=project_id, + ) + ) + assert len(query_res.result) == 1 + assert query_res.result[0] == { + "id": create_res.id, + "project_id": project_id, + "weave_ref": weave_ref, + "wb_user_id": "shawn", + "creator": None, + "created_at": create_res.created_at.isoformat().replace("T", " "), + "feedback_type": feedback_type, + "payload": payload, + } + + +def test_runnable_feedback(client: WeaveClient) -> None: + """Test feedback creation with runnable references.""" + project_id = client._project_id() + runnable_name = "runnable_name" + feedback_type = f"wandb.runnable.{runnable_name}" + weave_ref = f"weave:///{project_id}/call/cal_id_123" + runnable_ref = f"weave:///{project_id}/op/{runnable_name}:op_id_123" + call_ref = 
f"weave:///{project_id}/call/call_id_123" + trigger_ref = f"weave:///{project_id}/object/{runnable_name}:trigger_id_123" + payload = {"output": 1} + + # Case 1: Errors with no name in type (dangle or char len 0) + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.runnable", # No name + payload=payload, + runnable_ref=runnable_ref, + ) + ) + + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.runnable.", # Trailing period + payload=payload, + runnable_ref=runnable_ref, + ) + ) + + # Case 2: Errors with incorrect ref string format + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload=payload, + runnable_ref=f"weave:///{project_id}/op/{runnable_name}", # No digest + ) + ) + + # Case 3: Errors with name mismatch + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type + "_wrong_name", + payload=payload, + runnable_ref=runnable_ref, + ) + ) + + # Case 4: Errors if runnable ref is present but incorrect type + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="not.runnable", # Wrong type + payload=payload, + runnable_ref=runnable_ref, # Wrong type + ) + ) + + # Case 5: Errors if call ref is present but incorrect type + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="not.runnable", # Wrong type + payload=payload, + call_ref=call_ref, + ) + ) + + # Case 6: Errors if trigger ref is present 
but incorrect type + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="not.runnable", + payload=payload, + trigger_ref=trigger_ref, + ) + ) + + # Case 7: Invalid payload + with pytest.raises(InvalidRequest): + client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload={"not": "a valid payload"}, + runnable_ref=runnable_ref, + call_ref=call_ref, + trigger_ref=trigger_ref, + ) + ) + + # Success + create_res = client.server.feedback_create( + tsi.FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type=feedback_type, + payload=payload, + runnable_ref=runnable_ref, + call_ref=call_ref, + trigger_ref=trigger_ref, + ) + ) + assert create_res.id is not None + + # Verify Query Result Payload + query_res = client.server.feedback_query( + tsi.FeedbackQueryReq( + project_id=project_id, + ) + ) + assert len(query_res.result) == 1 + assert query_res.result[0] == { + "id": create_res.id, + "project_id": project_id, + "weave_ref": weave_ref, + "wb_user_id": "shawn", + "creator": None, + "created_at": create_res.created_at.isoformat().replace("T", " "), + "feedback_type": feedback_type, + "payload": payload, + } diff --git a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 55813115d43..c6ded0b2671 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, TypeVar, Union, overload +from typing import Any, Optional, Tuple, Type, TypeVar, Union, overload from pydantic import BaseModel, ValidationError @@ -34,6 +34,19 @@ ANNOTATION_FEEDBACK_TYPE_PREFIX = "wandb.annotation" RUNNABLE_FEEDBACK_TYPE_PREFIX = "wandb.runnable" + +# Making the decision to use `value` & `payload` as nested keys so that +# we can: +# 1. 
Add more fields in the future without breaking changes +# 2. Support primitive values for annotation feedback that still schema +class AnnotationPayloadSchema(BaseModel): + value: Any + + +class RunnablePayloadSchema(BaseModel): + output: Any + + T = TypeVar( "T", ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef, ri.InternalOpRef ) @@ -67,7 +80,10 @@ def _ensure_ref_is_valid( Raises: InvalidRequest: If the reference is invalid or doesn't match expected_type """ - parsed_ref = ri.parse_internal_uri(ref) + try: + parsed_ref = ri.parse_internal_uri(ref) + except ValueError as e: + raise InvalidRequest(f"Invalid ref: {ref}, {e}") if expected_type and not isinstance(parsed_ref, expected_type): raise InvalidRequest( f"Invalid ref: {ref}, expected {(t.__name__ for t in expected_type)}" @@ -86,7 +102,9 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: ) # Validate the required fields for the feedback type. - if req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX): + is_annotation = req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX) + is_runnable = req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX) + if is_annotation: if not req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX + "."): raise InvalidRequest( f"Invalid annotation feedback type: {req.feedback_type}" @@ -101,7 +119,17 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: raise InvalidRequest( f"annotation_ref must point to an object with name {type_subname}" ) - elif req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX): + try: + AnnotationPayloadSchema.model_validate(req.payload) + except ValidationError as e: + raise InvalidRequest( + f"Invalid payload for feedback_type {req.feedback_type}: {e}" + ) + elif req.annotation_ref: + raise InvalidRequest( + "annotation_ref is not allowed for non-annotation feedback" + ) + elif is_runnable: if not req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX + "."): raise 
InvalidRequest(f"Invalid runnable feedback type: {req.feedback_type}") type_subname = req.feedback_type[len(RUNNABLE_FEEDBACK_TYPE_PREFIX) + 1 :] @@ -116,6 +144,18 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: ) if isinstance(runnable_ref, ri.InternalOpRef) and not req.call_ref: raise InvalidRequest("call_ref is required for runnable feedback on ops") + try: + RunnablePayloadSchema.model_validate(req.payload) + except ValidationError as e: + raise InvalidRequest( + f"Invalid payload for feedback_type {req.feedback_type}: {e}" + ) + elif req.runnable_ref: + raise InvalidRequest("runnable_ref is not allowed for non-runnable feedback") + elif req.call_ref: + raise InvalidRequest("call_ref is not allowed for non-runnable feedback") + elif req.trigger_ref: + raise InvalidRequest("trigger_ref is not allowed for non-runnable feedback") # Validate the ref formats (we could even query the DB to ensure they exist and are valid) if req.annotation_ref: From 1dfcd7aceb1884eaebd51e41ad3daedb4b35d59f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 13:05:05 -0800 Subject: [PATCH 004/120] Modify scorers and uptake changes - make initial test changes --- tests/trace/test_evaluations.py | 18 +++++++++++++----- weave/trace/feedback_types/score.py | 13 ------------- weave/trace/weave_client.py | 14 +++++--------- 3 files changed, 18 insertions(+), 27 deletions(-) delete mode 100644 weave/trace/feedback_types/score.py diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index d137a92d4ef..6fa52846286 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -10,7 +10,7 @@ from tests.trace.util import AnyIntMatcher from weave import Evaluation, Model from weave.scorers import Scorer -from weave.trace.feedback_types.score import SCORE_TYPE_NAME +from weave.trace.refs import CallRef from weave.trace.weave_client import get_ref from weave.trace_server import trace_server_interface as tsi @@ 
-935,6 +935,7 @@ def model_function(col1, col2): ), "No matches should be found for AnotherDummyScorer" +@pytest.mark.asyncio async def test_feedback_is_correctly_linked(client): @weave.op def predict(text: str) -> str: @@ -961,7 +962,14 @@ def score(text, model_output) -> bool: feedbacks = calls.calls[0].summary["weave"]["feedback"] assert len(feedbacks) == 1 feedback = feedbacks[0] - assert feedback["feedback_type"] == SCORE_TYPE_NAME - assert feedback["payload"]["name"] == "score" - assert feedback["payload"]["op_ref"] == get_ref(score).uri() - assert feedback["payload"]["results"] == True + assert feedback["feedback_type"] == "wandb.runnable.score" + assert feedback["payload"] == {"output": True} + assert feedback["runnable_ref"] == get_ref(score).uri() + assert ( + feedback["call_ref"] + == CallRef( + entity=client.entity, + project=client.project, + call_id=list(score.get_calls())[0].id, + ).uri() + ) diff --git a/weave/trace/feedback_types/score.py b/weave/trace/feedback_types/score.py deleted file mode 100644 index e15438658c7..00000000000 --- a/weave/trace/feedback_types/score.py +++ /dev/null @@ -1,13 +0,0 @@ -# This type is still "Beta" and the underlying payload might change as well. -# We're using "beta.1" to indicate that this is a pre-release version. 
-from typing import TypedDict - -SCORE_TYPE_NAME = "wandb.score.beta.1" - - -class ScoreTypePayload(TypedDict): - name: str - op_ref: str - call_ref: str - results: dict - # supervision: dict diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 7f4d1d88ef0..416164a90dc 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -18,7 +18,6 @@ from weave.trace.context import weave_client_context as weave_client_context from weave.trace.exception import exception_to_json_str from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery -from weave.trace.feedback_types.score import SCORE_TYPE_NAME, ScoreTypePayload from weave.trace.object_record import ( ObjectRecord, dataclass_object_record, @@ -1106,8 +1105,6 @@ def _add_score( Outstanding questions: - Should we somehow include supervision (ie. the ground truth) in the payload? - - What should the shape of `ScoreTypePayload` be? Maybe we want the results to be top-level? - - What should we use for name? A standard "score" or the score name? """ # Parse the refs (acts as validation) call_ref = parse_uri(call_ref_uri) @@ -1132,18 +1129,17 @@ def _add_score( # # Prepare the supervision payload - payload: ScoreTypePayload = { - "name": score_name, - "op_ref": scorer_op_ref_uri, - "call_ref": scorer_call_ref_uri, - "results": results_json, + payload = { + "output": results_json, } freq = FeedbackCreateReq( project_id=self._project_id(), weave_ref=call_ref_uri, - feedback_type=SCORE_TYPE_NAME, # should this be score_name? + feedback_type="wandb.runnable." 
+ score_name, payload=payload, + runnable_ref=scorer_op_ref_uri, + call_ref=scorer_call_ref_uri, ) response = self.server.feedback_create(freq) From 7d03c21a60efa75627c558084f26cb15bea4a55b Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 13:13:34 -0800 Subject: [PATCH 005/120] Implemented initial query-side improvements --- tests/trace/test_evaluations.py | 2 +- tests/trace/test_feedback.py | 8 ++++++++ weave/trace_server/feedback.py | 4 ++++ weave/trace_server/sqlite_trace_server.py | 4 ++++ weave/trace_server/trace_server_common.py | 4 ++++ 5 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index 6fa52846286..d20a82db852 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -970,6 +970,6 @@ def score(text, model_output) -> bool: == CallRef( entity=client.entity, project=client.project, - call_id=list(score.get_calls())[0].id, + id=list(score.calls())[0].id, ).uri() ) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 06ed7033ab6..9a432117499 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -171,6 +171,10 @@ def test_annotation_feedback(client: WeaveClient) -> None: "created_at": create_res.created_at.isoformat().replace("T", " "), "feedback_type": feedback_type, "payload": payload, + "annotation_ref": annotation_ref, + "runnable_ref": None, + "call_ref": None, + "trigger_ref": None, } @@ -312,4 +316,8 @@ def test_runnable_feedback(client: WeaveClient) -> None: "created_at": create_res.created_at.isoformat().replace("T", " "), "feedback_type": feedback_type, "payload": payload, + "annotation_ref": None, + "runnable_ref": runnable_ref, + "call_ref": call_ref, + "trigger_ref": trigger_ref, } diff --git a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index c6ded0b2671..72303c42183 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ 
-22,6 +22,10 @@ Column("created_at", "datetime"), Column("feedback_type", "string"), Column("payload", "json", db_name="payload_dump"), + Column("annotation_ref", "string", nullable=True), + Column("runnable_ref", "string", nullable=True), + Column("call_ref", "string", nullable=True), + Column("trigger_ref", "string", nullable=True), ], ) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 1fd50cf1cb2..94c0d8c6ecc 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1007,6 +1007,10 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "feedback_type": req.feedback_type, "payload": req.payload, "created_at": created_at, + "annotation_ref": req.annotation_ref, + "runnable_ref": req.runnable_ref, + "call_ref": req.call_ref, + "trigger_ref": req.trigger_ref, } conn, cursor = get_conn_cursor(self.db_path) with self.lock: diff --git a/weave/trace_server/trace_server_common.py b/weave/trace_server/trace_server_common.py index 991e31ff3cb..b55e5cfb10b 100644 --- a/weave/trace_server/trace_server_common.py +++ b/weave/trace_server/trace_server_common.py @@ -37,6 +37,10 @@ def make_feedback_query_req( "creator", "created_at", "wb_user_id", + "runnable_ref", + "call_ref", + "trigger_ref", + "annotation_ref", ], query=query, ) From 906a44871e1c43eab8e4ca4507af06416af83503 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 14:01:11 -0800 Subject: [PATCH 006/120] Implemented initial feedback query tests (failing) --- tests/trace/test_feedback.py | 91 +++++++++++++++++++++++++++++++++++- weave/trace/weave_client.py | 40 ++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 9a432117499..5434fa68be4 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -1,6 +1,7 @@ import pytest -from weave.trace.weave_client import 
WeaveClient +import weave +from weave.trace.weave_client import WeaveClient, get_ref from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest @@ -321,3 +322,91 @@ def test_runnable_feedback(client: WeaveClient) -> None: "call_ref": call_ref, "trigger_ref": trigger_ref, } + + +def populate_feedback(client: WeaveClient) -> None: + @weave.op + def my_scorer(x: int, output: int) -> int: + expected = ["a", "b", "c", "d"][x] + return { + "model_output": output, + "expected": expected, + "match": output == expected, + } + + @weave.op + def my_model(x: int) -> str: + return [ + "a", + "x", # intentional "mistake" + "b", + "y", # intentional "mistake" + ][x] + + ids = [] + for x in range(4): + _, c = my_model.call(x) + ids.append(c.id) + c._apply_scorer(my_scorer) + + assert len(list(my_scorer.calls())) == 4 + assert len(list(my_model.calls())) == 4 + + return ids, my_scorer, my_model + + +def test_sort_by_feedback(client: WeaveClient) -> None: + """Test sorting by feedback.""" + ids, my_scorer, my_model = populate_feedback(client) + + for field, direction, id_exp in [ + ( + "feedback[wandb.runnable.my_scorer].payload.model_output", + "asc", + [ids[0], ids[2], ids[1], ids[3]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.model_output", + "desc", + [ids[3], ids[1], ids[2], ids[0]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.expected", + "asc", + [ids[3], ids[1], ids[2], ids[0]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.expected", + "desc", + [ids[0], ids[2], ids[1], ids[3]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.match", + "asc", + [ids[0], ids[2], ids[1], ids[3]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.match", + "desc", + [ids[3], ids[1], ids[2], ids[0]], + ), + ]: + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + sort_by=tsi.SortBy( + 
field=field, + direction=direction, + ), + include_feedback=True, + ) + ) + + assert [c.id for c in calls] == id_exp + + +def test_filter_by_feedback(client: WeaveClient) -> None: + """Test filtering by feedback.""" + ids, my_scorer, my_model = populate_feedback(client) + raise NotImplementedError diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 416164a90dc..df55a914854 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -294,6 +294,46 @@ def set_display_name(self, name: Optional[str]) -> None: def remove_display_name(self) -> None: self.set_display_name(None) + def _apply_scorer(self, scorer_op: Op) -> None: + """ + This is a private method that applies a scorer to a call and records the feedback. + In the near future, this will be made public, but for now it is only used internally + for testing. + + Before making this public, we should refactor such that the `predict_and_score` method + inside `eval.py` uses this method inside the scorer block. + + Current limitations: + - only works for ops (not Scorer class) + - no async support + - no context yet (ie. 
ground truth) + """ + client = weave_client_context.require_weave_client() + scorer_signature = scorer_op.signature + scorer_arg_names = list(scorer_signature.parameters.keys()) + score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} + if "output" in scorer_arg_names: + score_args["output"] = self.output + _, score_call = scorer_op.call(**score_args) + scorer_op_ref = get_ref(scorer_op) + if scorer_op_ref is None: + raise ValueError("Scorer op has no ref") + self_ref = get_ref(self) + if self_ref is None: + raise ValueError("Call has no ref") + score_name = scorer_op_ref.name + score_results = score_call.output + score_call_ref = get_ref(score_call) + if score_call_ref is None: + raise ValueError("Score call has no ref") + return client._add_score( + call_ref_uri=self_ref.uri(), + score_name=score_name, + score_results=score_results, + scorer_call_ref_uri=score_call_ref.uri(), + scorer_op_ref_uri=scorer_op_ref.uri(), + ) + class CallsIter: server: TraceServerInterface From c0dc64119523bca8c1e010e2fcca5bd9083abaad Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 14:05:22 -0800 Subject: [PATCH 007/120] Implemented initial feedback query tests (failing) --- tests/trace/test_feedback.py | 94 +++++++++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 22 deletions(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 5434fa68be4..21fba1528fa 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -359,37 +359,19 @@ def test_sort_by_feedback(client: WeaveClient) -> None: """Test sorting by feedback.""" ids, my_scorer, my_model = populate_feedback(client) - for field, direction, id_exp in [ + for field, asc_ids, desc_ids in [ ( "feedback[wandb.runnable.my_scorer].payload.model_output", - "asc", [ids[0], ids[2], ids[1], ids[3]], ), - ( - "feedback[wandb.runnable.my_scorer].payload.model_output", - "desc", - [ids[3], ids[1], ids[2], ids[0]], - ), ( 
"feedback[wandb.runnable.my_scorer].payload.expected", - "asc", [ids[3], ids[1], ids[2], ids[0]], ), - ( - "feedback[wandb.runnable.my_scorer].payload.expected", - "desc", - [ids[0], ids[2], ids[1], ids[3]], - ), ( "feedback[wandb.runnable.my_scorer].payload.match", - "asc", [ids[0], ids[2], ids[1], ids[3]], ), - ( - "feedback[wandb.runnable.my_scorer].payload.match", - "desc", - [ids[3], ids[1], ids[2], ids[0]], - ), ]: calls = client.server.calls_query_stream( tsi.CallsQueryReq( @@ -397,16 +379,84 @@ def test_sort_by_feedback(client: WeaveClient) -> None: filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), sort_by=tsi.SortBy( field=field, - direction=direction, + direction="asc", ), include_feedback=True, ) ) - assert [c.id for c in calls] == id_exp + assert [c.id for c in calls] == asc_ids + + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + sort_by=tsi.SortBy( + field=field, + direction="desc", + ), + include_feedback=True, + ) + ) + + assert [c.id for c in calls] == desc_ids def test_filter_by_feedback(client: WeaveClient) -> None: """Test filtering by feedback.""" ids, my_scorer, my_model = populate_feedback(client) - raise NotImplementedError + for field, value, eq_ids, gt_ids in [ + ( + "feedback[wandb.runnable.my_scorer].payload.model_output", + "a", + [ids[0], ids[2]], + [ids[1], ids[3]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.expected", + "a", + [ids[3], ids[1]], + [ids[0], ids[2]], + ), + ( + "feedback[wandb.runnable.my_scorer].payload.match", + True, + [ids[0], ids[2]], + [ids[1], ids[3]], + ), + ]: + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + query={ + "$expr": { + "$eq": [ + {"$getField": field}, + {"$literal": value}, + ] + } + }, + include_feedback=True, + ) + ) + + assert [c.id for c in calls] == 
eq_ids + + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + query={ + "$expr": { + "$gt": [ + {"$getField": field}, + {"$literal": value}, + ] + } + }, + include_feedback=True, + ) + ) + + assert [c.id for c in calls] == gt_ids From 00fd58768f32d7f2fbf92e0cbacbc2fbd4e30d85 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 17:44:23 -0800 Subject: [PATCH 008/120] Initial sort implementation --- tests/trace/test_feedback.py | 71 +++++++++------ weave/trace_server/calls_query_builder.py | 101 ++++++++++++++++++---- 2 files changed, 128 insertions(+), 44 deletions(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 21fba1528fa..c8f56c766f7 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -339,7 +339,7 @@ def my_model(x: int) -> str: return [ "a", "x", # intentional "mistake" - "b", + "c", "y", # intentional "mistake" ][x] @@ -359,47 +359,60 @@ def test_sort_by_feedback(client: WeaveClient) -> None: """Test sorting by feedback.""" ids, my_scorer, my_model = populate_feedback(client) - for field, asc_ids, desc_ids in [ + for fields, asc_ids in [ ( - "feedback[wandb.runnable.my_scorer].payload.model_output", + ["feedback.[wandb.runnable.my_scorer].payload.output.model_output"], [ids[0], ids[2], ids[1], ids[3]], ), ( - "feedback[wandb.runnable.my_scorer].payload.expected", - [ids[3], ids[1], ids[2], ids[0]], + ["feedback.[wandb.runnable.my_scorer].payload.output.expected"], + [ids[0], ids[1], ids[2], ids[3]], ), ( - "feedback[wandb.runnable.my_scorer].payload.match", - [ids[0], ids[2], ids[1], ids[3]], + [ + "feedback.[wandb.runnable.my_scorer].payload.output.match", + "feedback.[wandb.runnable.my_scorer].payload.output.model_output", + ], + [ids[1], ids[3], ids[0], ids[2]], ), ]: calls = client.server.calls_query_stream( tsi.CallsQueryReq( project_id=client._project_id(), - 
filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), - sort_by=tsi.SortBy( - field=field, - direction="asc", - ), - include_feedback=True, + filter=tsi.CallsFilter(op_names=[get_ref(my_model).uri()]), + sort_by=[ + tsi.SortBy( + field=field, + direction="asc", + ) + for field in fields + ], ) ) - assert [c.id for c in calls] == asc_ids + found_ids = [c.id for c in calls] + assert ( + found_ids == asc_ids + ), f"Sorting by {fields} ascending failed, expected {asc_ids}, got {found_ids}" calls = client.server.calls_query_stream( tsi.CallsQueryReq( project_id=client._project_id(), - filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), - sort_by=tsi.SortBy( - field=field, - direction="desc", - ), - include_feedback=True, + filter=tsi.CallsFilter(op_names=[get_ref(my_model).uri()]), + sort_by=[ + tsi.SortBy( + field=field, + direction="desc", + ) + for field in fields + ], ) ) - assert [c.id for c in calls] == desc_ids + found_ids = [c.id for c in calls] + assert ( + found_ids == asc_ids[::-1] + ), f"Sorting by {fields} descending failed, expected {asc_ids[::-1]}, got {found_ids}" def test_filter_by_feedback(client: WeaveClient) -> None: @@ -407,19 +420,19 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ids, my_scorer, my_model = populate_feedback(client) for field, value, eq_ids, gt_ids in [ ( - "feedback[wandb.runnable.my_scorer].payload.model_output", + "feedback.[wandb.runnable.my_scorer].payload.model_output", "a", [ids[0], ids[2]], [ids[1], ids[3]], ), ( - "feedback[wandb.runnable.my_scorer].payload.expected", + "feedback.[wandb.runnable.my_scorer].payload.expected", "a", [ids[3], ids[1]], [ids[0], ids[2]], ), ( - "feedback[wandb.runnable.my_scorer].payload.match", + "feedback.[wandb.runnable.my_scorer].payload.match", True, [ids[0], ids[2]], [ids[1], ids[3]], @@ -428,7 +441,7 @@ def test_filter_by_feedback(client: WeaveClient) -> None: calls = client.server.calls_query_stream( tsi.CallsQueryReq( project_id=client._project_id(), - 
filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + filter=tsi.CallsFilter(op_names=[get_ref(my_model).uri()]), query={ "$expr": { "$eq": [ @@ -441,12 +454,13 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ) ) - assert [c.id for c in calls] == eq_ids + found_ids = [c.id for c in calls] + assert found_ids == eq_ids calls = client.server.calls_query_stream( tsi.CallsQueryReq( project_id=client._project_id(), - filter=tsi.CallsFilter(op_name=[get_ref(my_model).uri()]), + filter=tsi.CallsFilter(op_names=[get_ref(my_model).uri()]), query={ "$expr": { "$gt": [ @@ -459,4 +473,5 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ) ) - assert [c.id for c in calls] == gt_ids + found_ids = [c.id for c in calls] + assert found_ids == gt_ids diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 5ba0bd686c2..9103755e2b4 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -23,11 +23,11 @@ * [ ] Implement column selection at interface level so that it can be used here * [ ] Consider how we will do latency order/filter -* [ ] Consider how we will do feedback fields """ import logging +import re import typing import sqlparse @@ -114,6 +114,48 @@ def is_heavy(self) -> bool: return True +class CallsMergedFeedbackPayloadField(CallsMergedField): + feedback_type: str + extra_path: list[str] + + @classmethod + def from_path(cls, path: str) -> "CallsMergedFeedbackPayloadField": + """Expected format: `[feedback.type].dot.path`""" + regex = re.compile(r"^(\[.+\])\.(.+\..+)$") + match = regex.match(path) + if not match: + raise InvalidFieldError(f"Invalid feedback path: {path}") + feedback_type, path = match.groups() + if feedback_type[0] != "[" or feedback_type[-1] != "]": + raise InvalidFieldError(f"Invalid feedback type: {feedback_type}") + extra_path = path.split(".") + if extra_path[0] != "payload": + raise InvalidFieldError(f"Invalid feedback path: 
{path}") + feedback_type = feedback_type[1:-1] + return CallsMergedFeedbackPayloadField( + field="payload_dump", feedback_type=feedback_type, extra_path=extra_path[1:] + ) + + def is_heavy(self) -> bool: + return True + + def as_sql( + self, + pb: ParamBuilder, + table_alias: str, + cast: typing.Optional[tsi_query.CastTo] = None, + ) -> str: + inner = super().as_sql(pb, "feedback") + param_name = pb.add_param(self.feedback_type) + res = f"anyIf({inner}, feedback.feedback_type = {_param_slot(param_name, 'String')})" + return json_dump_field_as_sql(pb, "feedback", res, self.extra_path, cast) + + def as_select_sql(self, pb: ParamBuilder, table_alias: str) -> str: + raise NotImplementedError( + "Feedback fields cannot be selected directly, yet - implement me!" + ) + + class QueryBuilderDynamicField(QueryBuilderField): # This is a temporary solution to address a specific use case. # We need to reuse the `CallsMergedDynamicField` mechanics in the table_query, @@ -179,7 +221,14 @@ class OrderField(BaseModel): def as_sql(self, pb: ParamBuilder, table_alias: str) -> str: options: list[typing.Tuple[typing.Optional[tsi_query.CastTo], str]] - if isinstance(self.field, (QueryBuilderDynamicField, CallsMergedDynamicField)): + if isinstance( + self.field, + ( + QueryBuilderDynamicField, + CallsMergedDynamicField, + CallsMergedFeedbackPayloadField, + ), + ): # Prioritize existence, then cast to double, then str options = [ ("exists", "desc"), @@ -519,6 +568,7 @@ def _as_sql_base_format( table_alias: str, id_subquery_name: typing.Optional[str] = None, ) -> str: + needs_feedback = False select_fields_sql = ", ".join( field.as_select_sql(pb, table_alias) for field in self.select_fields ) @@ -545,6 +595,9 @@ def _as_sql_base_format( for order_field in self.order_fields ] ) + for order_field in self.order_fields: + if isinstance(order_field.field, CallsMergedFeedbackPayloadField): + needs_feedback = True limit_sql = "" if self.limit is not None: @@ -556,23 +609,36 @@ def 
_as_sql_base_format( id_subquery_sql = "" if id_subquery_name is not None: - id_subquery_sql = f"AND (id IN {id_subquery_name})" + id_subquery_sql = f"AND (calls_merged.id IN {id_subquery_name})" project_param = pb.add_param(self.project_id) # Special Optimization id_mask_sql = "" if self.hardcoded_filter and self.hardcoded_filter.filter.call_ids: - id_mask_sql = f"AND (id IN {_param_slot(pb.add_param(self.hardcoded_filter.filter.call_ids), 'Array(String)')})" + id_mask_sql = f"AND (calls_merged.id IN {_param_slot(pb.add_param(self.hardcoded_filter.filter.call_ids), 'Array(String)')})" # TODO: We should also pull out id-masks from the dynamic query + feedback_join_sql = "" + feedback_where_sql = "" + if needs_feedback: + feedback_where_sql = ( + f" AND calls_merged.project_id = {_param_slot(project_param, 'String')}" + ) + feedback_join_sql = f""" + LEFT JOIN feedback + ON (feedback.weave_ref = concat('weave-trace-internal:///', {_param_slot(project_param, 'String')}, '/call/', calls_merged.id)) + """ + raw_sql = f""" SELECT {select_fields_sql} FROM calls_merged - WHERE project_id = {_param_slot(project_param, 'String')} + {feedback_join_sql} + WHERE calls_merged.project_id = {_param_slot(project_param, 'String')} + {feedback_where_sql} {id_mask_sql} {id_subquery_sql} - GROUP BY (project_id, id) + GROUP BY (calls_merged.project_id, calls_merged.id) {having_filter_sql} {order_by_sql} {limit_sql} @@ -606,16 +672,19 @@ def _as_sql_base_format( def get_field_by_name(name: str) -> CallsMergedField: if name not in ALLOWED_CALL_FIELDS: - field_parts = name.split(".") - start_part = field_parts[0] - dumped_start_part = start_part + "_dump" - if dumped_start_part in ALLOWED_CALL_FIELDS: - field = ALLOWED_CALL_FIELDS[dumped_start_part] - if isinstance(field, CallsMergedDynamicField): - if len(field_parts) > 1: - return field.with_path(field_parts[1:]) - return field - raise InvalidFieldError(f"Field {name} is not allowed") + if name.startswith("feedback."): + return 
CallsMergedFeedbackPayloadField.from_path(name[len("feedback.") :]) + else: + field_parts = name.split(".") + start_part = field_parts[0] + dumped_start_part = start_part + "_dump" + if dumped_start_part in ALLOWED_CALL_FIELDS: + field = ALLOWED_CALL_FIELDS[dumped_start_part] + if isinstance(field, CallsMergedDynamicField): + if len(field_parts) > 1: + return field.with_path(field_parts[1:]) + return field + raise InvalidFieldError(f"Field {name} is not allowed") return ALLOWED_CALL_FIELDS[name] From 29ef40ba3bcbb7ff8f145ce59c8499daf72e7818 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 17:52:54 -0800 Subject: [PATCH 009/120] Other Sort Tests --- .../trace_server/test_calls_query_builder.py | 121 ++++++++++++++++++ weave/trace_server/orm.py | 10 ++ 2 files changed, 131 insertions(+) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index 23716a13a68..ce8b50dd14a 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -444,3 +444,124 @@ def test_query_light_column_with_costs() -> None: "pb_3": 1, }, ) + + +def test_query_with_simple_feedback_sort() -> None: + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.add_order("feedback.[wandb.runnable.my_op].payload.output.expected", "desc") + assert_sql( + cq, + """ + SELECT + calls_merged.id AS id + FROM + calls_merged + LEFT JOIN feedback ON + (feedback.weave_ref = concat('weave-trace-internal:///', + {pb_4:String}, + '/call/', + calls_merged.id)) + WHERE + calls_merged.project_id = {pb_4:String} + AND calls_merged.project_id = {pb_4:String} + GROUP BY + (calls_merged.project_id, + calls_merged.id) + HAVING + (((any(calls_merged.deleted_at) IS NULL)) + AND ((NOT ((any(calls_merged.started_at) IS NULL))))) + ORDER BY + (NOT (JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_1:String}, + {pb_2:String}) = 'Null' + OR 
JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_1:String}, + {pb_2:String}) IS NULL)) desc, + toFloat64OrNull(JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_3:String})) DESC, + toString(JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_3:String})) DESC + """, + { + "pb_0": "wandb.runnable.my_op", + "pb_1": "output", + "pb_2": "expected", + "pb_3": '$."output"."expected"', + "pb_4": "project", + }, + ) + + +def test_query_with_simple_feedback_sort_with_op_name() -> None: + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.set_hardcoded_filter( + HardCodedFilter( + filter={"op_names": ["weave-trace-internal:///project/op/my_op:1234567890"]} + ) + ) + cq.add_order("feedback.[wandb.runnable.my_op].payload.output.expected", "desc") + assert_sql( + cq, + """ + WITH filtered_calls AS + ( + SELECT + calls_merged.id AS id + FROM + calls_merged + WHERE + calls_merged.project_id = {pb_1:String} + GROUP BY + (calls_merged.project_id, + calls_merged.id) + HAVING + (((any(calls_merged.deleted_at) IS NULL)) + AND ((NOT ((any(calls_merged.started_at) IS NULL)))) + AND (any(calls_merged.op_name) IN {pb_0:Array(String)}))) + SELECT + calls_merged.id AS id + FROM + calls_merged + LEFT JOIN feedback ON + (feedback.weave_ref = concat('weave-trace-internal:///', + {pb_1:String}, + '/call/', + calls_merged.id)) + WHERE + calls_merged.project_id = {pb_1:String} + AND calls_merged.project_id = {pb_1:String} + AND (calls_merged.id IN filtered_calls) + GROUP BY + (calls_merged.project_id, + calls_merged.id) + ORDER BY + (NOT (JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_2:String}), + {pb_3:String}, + {pb_4:String}) = 'Null' + OR JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_2:String}), + {pb_3:String}, + {pb_4:String}) IS NULL)) desc, + toFloat64OrNull(JSON_VALUE(anyIf(feedback.payload_dump, + 
feedback.feedback_type = {pb_2:String}), + {pb_5:String})) DESC, + toString(JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_2:String}), + {pb_5:String})) DESC + """, + { + "pb_0": ["weave-trace-internal:///project/op/my_op:1234567890"], + "pb_1": "project", + "pb_2": "wandb.runnable.my_op", + "pb_3": "output", + "pb_4": "expected", + "pb_5": '$."output"."expected"', + }, + ) diff --git a/weave/trace_server/orm.py b/weave/trace_server/orm.py index 6f02f0c2644..e7bd2843e53 100644 --- a/weave/trace_server/orm.py +++ b/weave/trace_server/orm.py @@ -48,10 +48,20 @@ def __init__( self._params: typing.Dict[str, typing.Any] = {} self._prefix = (prefix or f"pb_{param_builder_count}") + "_" self._database_type = database_type + self._param_to_name: dict[typing.Any, str] = {} def add_param(self, param_value: typing.Any) -> str: + try: + if param_value in self._param_to_name: + return self._param_to_name[param_value] + except TypeError: + pass param_name = self._prefix + str(len(self._params)) self._params[param_name] = param_value + try: + self._param_to_name[param_value] = param_name + except TypeError: + pass return param_name def add( From 2af8f155e472b8e882880af1bb8d3428bded6a44 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 18:58:17 -0800 Subject: [PATCH 010/120] Initial filter tests --- tests/trace/test_feedback.py | 30 ++++++++++++----------- weave/trace_server/calls_query_builder.py | 17 +++++++++---- weave/trace_server/orm.py | 4 +-- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index c8f56c766f7..44a228bda1f 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -420,22 +420,22 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ids, my_scorer, my_model = populate_feedback(client) for field, value, eq_ids, gt_ids in [ ( - "feedback.[wandb.runnable.my_scorer].payload.model_output", + 
"feedback.[wandb.runnable.my_scorer].payload.output.model_output", "a", - [ids[0], ids[2]], - [ids[1], ids[3]], + [ids[0]], + [ids[1], ids[2], ids[3]], ), ( - "feedback.[wandb.runnable.my_scorer].payload.expected", - "a", - [ids[3], ids[1]], - [ids[0], ids[2]], + "feedback.[wandb.runnable.my_scorer].payload.output.expected", + "c", + [ids[2]], + [ids[3]], ), ( - "feedback.[wandb.runnable.my_scorer].payload.match", - True, - [ids[0], ids[2]], + "feedback.[wandb.runnable.my_scorer].payload.output.match", + "false", [ids[1], ids[3]], + [ids[0], ids[2]], ), ]: calls = client.server.calls_query_stream( @@ -450,12 +450,13 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ] } }, - include_feedback=True, ) ) found_ids = [c.id for c in calls] - assert found_ids == eq_ids + assert ( + found_ids == eq_ids + ), f"Filtering by {field} == {value} failed, expected {eq_ids}, got {found_ids}" calls = client.server.calls_query_stream( tsi.CallsQueryReq( @@ -469,9 +470,10 @@ def test_filter_by_feedback(client: WeaveClient) -> None: ] } }, - include_feedback=True, ) ) found_ids = [c.id for c in calls] - assert found_ids == gt_ids + assert ( + found_ids == gt_ids + ), f"Filtering by {field} > {value} failed, expected {gt_ids}, got {found_ids}" diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 9103755e2b4..a5d0e6b361f 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -258,7 +258,8 @@ def as_sql(self, pb: ParamBuilder, table_alias: str) -> str: if self._consumed_fields is None: self._consumed_fields = [] for field in conditions.fields_used: - self._consumed_fields.append(get_field_by_name(field)) + # TODO: Verify that this is "ok" since before we were just looking at field name + self._consumed_fields.append(field) return combine_conditions(conditions.conditions, "AND") def _get_consumed_fields(self) -> list[CallsMergedField]: @@ -579,6 +580,10 @@ def 
_as_sql_base_format( having_conditions_sql.extend( c.as_sql(pb, table_alias) for c in self.query_conditions ) + for query_condition in self.query_conditions: + for field in query_condition._get_consumed_fields(): + if isinstance(field, CallsMergedFeedbackPayloadField): + needs_feedback = True if self.hardcoded_filter is not None: having_conditions_sql.append(self.hardcoded_filter.as_sql(pb, table_alias)) @@ -690,7 +695,7 @@ def get_field_by_name(name: str) -> CallsMergedField: class FilterToConditions(BaseModel): conditions: list[str] - fields_used: set[str] + fields_used: list[CallsMergedField] def process_query_to_conditions( @@ -700,7 +705,7 @@ def process_query_to_conditions( ) -> FilterToConditions: """Converts a Query to a list of conditions for a clickhouse query.""" conditions = [] - raw_fields_used = set() + raw_fields_used: dict[str, CallsMergedField] = {} # This is the mongo-style query def process_operation(operation: tsi_query.Operation) -> str: @@ -766,7 +771,7 @@ def process_operand(operand: "tsi_query.Operand") -> str: elif isinstance(operand, tsi_query.GetFieldOperator): structured_field = get_field_by_name(operand.get_field_) field = structured_field.as_sql(param_builder, table_alias) - raw_fields_used.add(structured_field.field) + raw_fields_used[structured_field.field] = structured_field return field elif isinstance(operand, tsi_query.ConvertOperation): field = process_operand(operand.convert_.input) @@ -792,7 +797,9 @@ def process_operand(operand: "tsi_query.Operand") -> str: conditions.append(filter_cond) - return FilterToConditions(conditions=conditions, fields_used=raw_fields_used) + return FilterToConditions( + conditions=conditions, fields_used=list(raw_fields_used.values()) + ) def process_calls_filter_to_conditions( diff --git a/weave/trace_server/orm.py b/weave/trace_server/orm.py index e7bd2843e53..3b5f0345c5c 100644 --- a/weave/trace_server/orm.py +++ b/weave/trace_server/orm.py @@ -484,12 +484,12 @@ def python_value_to_ch_type(value: 
typing.Any) -> str: """Helper function to convert python types to clickhouse types.""" if isinstance(value, str): return "String" + elif isinstance(value, bool): + return "Bool" elif isinstance(value, int): return "UInt64" elif isinstance(value, float): return "Float64" - elif isinstance(value, bool): - return "UInt8" elif value is None: return "Nullable(String)" else: From 643ea0ec3d1d507521a5b2fc5668eedfbabc7304 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 19:15:27 -0800 Subject: [PATCH 011/120] Finished filter tests --- .../trace_server/test_calls_query_builder.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index ce8b50dd14a..8a565f69216 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -565,3 +565,124 @@ def test_query_with_simple_feedback_sort_with_op_name() -> None: "pb_5": '$."output"."expected"', }, ) + + +def test_query_with_simple_feedback_filter() -> None: + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.add_condition( + tsi_query.GtOperation.model_validate( + { + "$gt": [ + { + "$getField": "feedback.[wandb.runnable.my_op].payload.output.expected" + }, + { + "$getField": "feedback.[wandb.runnable.my_op].payload.output.found" + }, + ] + } + ) + ) + assert_sql( + cq, + """ + SELECT + calls_merged.id AS id + FROM + calls_merged + LEFT JOIN feedback ON + (feedback.weave_ref = concat('weave-trace-internal:///', + {pb_3:String}, + '/call/', + calls_merged.id)) + WHERE + calls_merged.project_id = {pb_3:String} + AND calls_merged.project_id = {pb_3:String} + GROUP BY + (calls_merged.project_id, + calls_merged.id) + HAVING + (((JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_1:String}) > JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_2:String}))) + AND 
((any(calls_merged.deleted_at) IS NULL)) + AND ((NOT ((any(calls_merged.started_at) IS NULL))))) + """, + { + "pb_0": "wandb.runnable.my_op", + "pb_1": '$."output"."expected"', + "pb_2": '$."output"."found"', + "pb_3": "project", + }, + ) + + +def test_query_with_simple_feedback_sort_and_filter() -> None: + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.add_condition( + tsi_query.EqOperation.model_validate( + { + "$eq": [ + { + "$getField": "feedback.[wandb.runnable.my_op].payload.output.expected" + }, + {"$literal": "a"}, + ] + } + ) + ) + cq.add_order("feedback.[wandb.runnable.my_op].payload.output.score", "desc") + assert_sql( + cq, + """ + SELECT + calls_merged.id AS id + FROM + calls_merged + LEFT JOIN feedback ON + (feedback.weave_ref = concat('weave-trace-internal:///', + {pb_6:String}, + '/call/', + calls_merged.id)) + WHERE + calls_merged.project_id = {pb_6:String} + AND calls_merged.project_id = {pb_6:String} + GROUP BY + (calls_merged.project_id, + calls_merged.id) + HAVING + (((JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_1:String}) = {pb_2:String})) + AND ((any(calls_merged.deleted_at) IS NULL)) + AND ((NOT ((any(calls_merged.started_at) IS NULL))))) + ORDER BY + (NOT (JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_3:String}, + {pb_4:String}) = 'Null' + OR JSONType(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_3:String}, + {pb_4:String}) IS NULL)) desc, + toFloat64OrNull(JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_5:String})) DESC, + toString(JSON_VALUE(anyIf(feedback.payload_dump, + feedback.feedback_type = {pb_0:String}), + {pb_5:String})) DESC + """, + { + "pb_0": "wandb.runnable.my_op", + "pb_1": '$."output"."expected"', + "pb_2": "a", + "pb_3": "output", + "pb_4": "score", + "pb_5": '$."output"."score"', + "pb_6": "project", + }, + ) From 
79e83fe6d7423bee30d5afdcad1cc92eb06580fb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 19:27:51 -0800 Subject: [PATCH 012/120] Test fix --- tests/trace/test_scores.py | 10 ++++------ weave/trace_server/clickhouse_trace_server_batched.py | 4 ++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/trace/test_scores.py b/tests/trace/test_scores.py index 2ccea93c237..bd44c072453 100644 --- a/tests/trace/test_scores.py +++ b/tests/trace/test_scores.py @@ -1,7 +1,6 @@ from concurrent.futures import Future import weave -from weave.trace.feedback_types.score import SCORE_TYPE_NAME from weave.trace.weave_client import get_ref from weave.trace_server import trace_server_interface as tsi @@ -39,9 +38,8 @@ def my_score(input_x: int, model_output: int) -> int: assert len(calls) == 2 feedback = calls[0].summary["weave"]["feedback"][0] - assert feedback["feedback_type"] == SCORE_TYPE_NAME + assert feedback["feedback_type"] == "wandb.runnable.my_score" assert feedback["weave_ref"] == get_ref(call).uri() - assert feedback["payload"]["name"] == "my_score" - assert feedback["payload"]["op_ref"] == get_ref(my_score).uri() - assert feedback["payload"]["call_ref"] == get_ref(score_call).uri() - assert feedback["payload"]["results"] == score_res + assert feedback["runnable_ref"] == get_ref(my_score).uri() + assert feedback["call_ref"] == get_ref(score_call).uri() + assert feedback["payload"] == {"output": score_res} diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 8d8f57aff65..ba3db28cc49 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1361,6 +1361,10 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "feedback_type": req.feedback_type, "payload": req.payload, "created_at": created_at, + "annotation_ref": req.annotation_ref, + "runnable_ref": req.runnable_ref, 
+ "call_ref": req.call_ref, + "trigger_ref": req.trigger_ref, } prepared = TABLE_FEEDBACK.insert(row).prepare(database_type="clickhouse") self._insert(TABLE_FEEDBACK.name, prepared.data, prepared.column_names) From b4963d5cfe20d3c44657a9374b32f67860d274ab Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 19:50:23 -0800 Subject: [PATCH 013/120] Fixed sqlite tests --- tests/trace/test_feedback.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 44a228bda1f..a5a3c412f40 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -1,6 +1,9 @@ +import datetime + import pytest import weave +from tests.trace.util import client_is_sqlite from weave.trace.weave_client import WeaveClient, get_ref from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest @@ -169,7 +172,10 @@ def test_annotation_feedback(client: WeaveClient) -> None: "weave_ref": weave_ref, "wb_user_id": "shawn", "creator": None, - "created_at": create_res.created_at.isoformat().replace("T", " "), + # Sad - seems like sqlite and clickhouse return different types here + "created_at": create_res.created_at.isoformat().replace("T", " ") + if client_is_sqlite(client) + else MatchAnyDatetime(), "feedback_type": feedback_type, "payload": payload, "annotation_ref": annotation_ref, @@ -314,7 +320,10 @@ def test_runnable_feedback(client: WeaveClient) -> None: "weave_ref": weave_ref, "wb_user_id": "shawn", "creator": None, - "created_at": create_res.created_at.isoformat().replace("T", " "), + # Sad - seems like sqlite and clickhouse return different types here + "created_at": create_res.created_at.isoformat().replace("T", " ") + if client_is_sqlite(client) + else MatchAnyDatetime(), "feedback_type": feedback_type, "payload": payload, "annotation_ref": None, @@ -356,6 +365,10 @@ def my_model(x: int) -> str: def 
test_sort_by_feedback(client: WeaveClient) -> None: + if client_is_sqlite(client): + # Not implemented in sqlite - skip + return pytest.skip() + """Test sorting by feedback.""" ids, my_scorer, my_model = populate_feedback(client) @@ -416,6 +429,10 @@ def test_sort_by_feedback(client: WeaveClient) -> None: def test_filter_by_feedback(client: WeaveClient) -> None: + if client_is_sqlite(client): + # Not implemented in sqlite - skip + return pytest.skip() + """Test filtering by feedback.""" ids, my_scorer, my_model = populate_feedback(client) for field, value, eq_ids, gt_ids in [ @@ -477,3 +494,8 @@ def test_filter_by_feedback(client: WeaveClient) -> None: assert ( found_ids == gt_ids ), f"Filtering by {field} > {value} failed, expected {gt_ids}, got {found_ids}" + + +class MatchAnyDatetime: + def __eq__(self, other): + return isinstance(other, datetime.datetime) From 22e29827c5d3449f077e8022da6ad3a2f67e570e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 20:01:12 -0800 Subject: [PATCH 014/120] Fixed sqlite tests 2 --- .../trace_server/test_calls_query_builder.py | 74 +++++++++---------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index 8a565f69216..6ed62b9c615 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -14,8 +14,8 @@ def test_query_baseline() -> None: """ SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_0:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_0:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( (( any(calls_merged.deleted_at) IS NULL @@ -43,8 +43,8 @@ def test_query_light_column() -> None: calls_merged.id AS id, any(calls_merged.started_at) AS started_at FROM calls_merged - WHERE project_id = {pb_0:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_0:String} + GROUP BY 
(calls_merged.project_id, calls_merged.id) HAVING ( (( any(calls_merged.deleted_at) IS NULL @@ -72,8 +72,8 @@ def test_query_heavy_column() -> None: calls_merged.id AS id, any(calls_merged.inputs_dump) AS inputs_dump FROM calls_merged - WHERE project_id = {pb_0:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_0:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( (( any(calls_merged.deleted_at) IS NULL @@ -108,8 +108,8 @@ def test_query_heavy_column_simple_filter() -> None: SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_1:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_1:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( ((any(calls_merged.deleted_at) IS NULL)) AND ((NOT ((any(calls_merged.started_at) IS NULL)))) @@ -121,12 +121,12 @@ def test_query_heavy_column_simple_filter() -> None: any(calls_merged.inputs_dump) AS inputs_dump FROM calls_merged WHERE - project_id = {pb_2:String} + calls_merged.project_id = {pb_1:String} AND - (id IN filtered_calls) - GROUP BY (project_id,id) + (calls_merged.id IN filtered_calls) + GROUP BY (calls_merged.project_id, calls_merged.id) """, - {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, + {"pb_0": ["a", "b"], "pb_1": "project"}, ) @@ -149,8 +149,8 @@ def test_query_heavy_column_simple_filter_with_order() -> None: SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_1:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_1:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( ((any(calls_merged.deleted_at) IS NULL)) AND ((NOT ((any(calls_merged.started_at) IS NULL)))) @@ -162,13 +162,13 @@ def test_query_heavy_column_simple_filter_with_order() -> None: any(calls_merged.inputs_dump) AS inputs_dump FROM calls_merged WHERE - project_id = {pb_2:String} + calls_merged.project_id = {pb_1:String} AND - (id IN filtered_calls) - GROUP BY (project_id,id) + 
(calls_merged.id IN filtered_calls) + GROUP BY (calls_merged.project_id, calls_merged.id) ORDER BY any(calls_merged.started_at) DESC """, - {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, + {"pb_0": ["a", "b"], "pb_1": "project"}, ) @@ -192,8 +192,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None: SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_1:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_1:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( ((any(calls_merged.deleted_at) IS NULL)) AND @@ -209,13 +209,13 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None: any(calls_merged.inputs_dump) AS inputs_dump FROM calls_merged WHERE - project_id = {pb_2:String} + calls_merged.project_id = {pb_1:String} AND - (id IN filtered_calls) - GROUP BY (project_id,id) + (calls_merged.id IN filtered_calls) + GROUP BY (calls_merged.project_id, calls_merged.id) ORDER BY any(calls_merged.started_at) DESC """, - {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, + {"pb_0": ["a", "b"], "pb_1": "project"}, ) @@ -258,8 +258,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_2:String} - GROUP BY (project_id,id) + WHERE calls_merged.project_id = {pb_2:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( ((any(calls_merged.wb_user_id) = {pb_0:String})) AND @@ -275,10 +275,10 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c any(calls_merged.inputs_dump) AS inputs_dump FROM calls_merged WHERE - project_id = {pb_5:String} + calls_merged.project_id = {pb_2:String} AND - (id IN filtered_calls) - GROUP BY (project_id,id) + (calls_merged.id IN filtered_calls) + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING ( JSON_VALUE(any(calls_merged.inputs_dump), {pb_3:String}) = {pb_4:String} ) @@ 
-291,7 +291,6 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c "pb_2": "project", "pb_3": '$."param"."val"', "pb_4": "hello", - "pb_5": "project", }, ) @@ -329,8 +328,8 @@ def test_query_light_column_with_costs() -> None: filtered_calls AS ( SELECT calls_merged.id AS id FROM calls_merged - WHERE project_id = {pb_1:String} - GROUP BY (project_id, id) + WHERE calls_merged.project_id = {pb_1:String} + GROUP BY (calls_merged.project_id, calls_merged.id) HAVING (((any(calls_merged.deleted_at) IS NULL)) AND ((NOT ((any(calls_merged.started_at) IS NULL)))) AND (any(calls_merged.op_name) IN {pb_0:Array(String)}))), @@ -339,9 +338,9 @@ def test_query_light_column_with_costs() -> None: calls_merged.id AS id, any(calls_merged.started_at) AS started_at FROM calls_merged - WHERE project_id = {pb_2:String} - AND (id IN filtered_calls) - GROUP BY (project_id, id)), + WHERE calls_merged.project_id = {pb_1:String} + AND (calls_merged.id IN filtered_calls) + GROUP BY (calls_merged.project_id, calls_merged.id)), -- From the all_calls we get the usage data for LLMs llm_usage AS ( SELECT @@ -434,14 +433,13 @@ def test_query_light_column_with_costs() -> None: '}' ) ) AS summary_dump FROM ranked_prices - WHERE (rank = {pb_3:UInt64}) + WHERE (rank = {pb_2:UInt64}) GROUP BY id, started_at """, { "pb_0": ["a", "b"], "pb_1": "UHJvamVjdEludGVybmFsSWQ6Mzk1NDg2Mjc=", - "pb_2": "UHJvamVjdEludGVybmFsSWQ6Mzk1NDg2Mjc=", - "pb_3": 1, + "pb_2": 1, }, ) From 2ac2557c03607f48e0aa3a0b504a1ee5445d80a1 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 4 Nov 2024 20:54:06 -0800 Subject: [PATCH 015/120] added one more test --- tests/trace/test_feedback.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index a5a3c412f40..264cb6e8c58 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -499,3 +499,39 @@ def test_filter_by_feedback(client: 
WeaveClient) -> None: class MatchAnyDatetime: def __eq__(self, other): return isinstance(other, datetime.datetime) + + +def test_filter_and_sort_by_feedback(client: WeaveClient) -> None: + if client_is_sqlite(client): + # Not implemented in sqlite - skip + return pytest.skip() + + """Test filtering by feedback.""" + ids, my_scorer, my_model = populate_feedback(client) + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + filter=tsi.CallsFilter(op_names=[get_ref(my_model).uri()]), + # Filter down to just correct matches + query={ + "$expr": { + "$eq": [ + { + "$getField": "feedback.[wandb.runnable.my_scorer].payload.output.match" + }, + {"$literal": "true"}, + ] + } + }, + # Sort by the model output desc + sort_by=[ + { + "field": "feedback.[wandb.runnable.my_scorer].payload.output.model_output", + "direction": "desc", + } + ], + ) + ) + calls = list(calls) + assert len(calls) == 2 + assert [c.id for c in calls] == [ids[2], ids[0]] From 63a37ce8fd3e8c50e15d7c55669d71b9f83d5c58 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 10:38:55 -0800 Subject: [PATCH 016/120] First incoming changes from Online Evals --- tests/trace/test_actions_e2e.py | 125 +++++++ weave/actions_worker/tasks.py | 252 +++++++++++++ .../clickhouse_trace_server_batched.py | 344 ++++++++++++++++-- ...ternal_to_internal_trace_server_adapter.py | 7 + .../base_models/action_base_models.py | 44 +++ weave/trace_server/sqlite_trace_server.py | 11 +- weave/trace_server/trace_server_interface.py | 32 +- .../remote_http_trace_server.py | 31 +- 8 files changed, 784 insertions(+), 62 deletions(-) create mode 100644 tests/trace/test_actions_e2e.py create mode 100644 weave/actions_worker/tasks.py create mode 100644 weave/trace_server/interface/base_models/action_base_models.py diff --git a/tests/trace/test_actions_e2e.py b/tests/trace/test_actions_e2e.py new file mode 100644 index 00000000000..548ec22ea94 --- /dev/null +++ 
b/tests/trace/test_actions_e2e.py @@ -0,0 +1,125 @@ +import pytest + +import weave +from weave.trace.refs import ObjectRef +from weave.trace.weave_client import WeaveClient +from weave.trace_server.interface.base_models.action_base_models import ( + ConfiguredAction, + ConfiguredContainsWordsAction, +) +from weave.trace_server.sqlite_trace_server import SqliteTraceServer +from weave.trace_server.trace_server_interface import ( + ActionsExecuteBatchReq, + FeedbackCreateReq, + ObjCreateReq, + ObjQueryReq, +) + + +def test_action_execute_workflow(client: WeaveClient): + is_sqlite = isinstance(client.server._internal_trace_server, SqliteTraceServer) # type: ignore + if is_sqlite: + # dont run this test for sqlite + return + + action_name = "test_action" + # part 1: create the action + digest = client.server.obj_create( + ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": action_name, + "base_object_class": "ConfiguredAction", + "val": ConfiguredAction( + name="test_action", + config=ConfiguredContainsWordsAction( + target_words=["mindful", "demure"] + ), + ).model_dump(), + } + } + ) + ).digest + + configured_actions = client.server.objs_query( + ObjQueryReq.model_validate( + { + "project_id": client._project_id(), + "filter": {"base_object_classes": ["ConfiguredAction"]}, + } + ) + ) + + assert len(configured_actions.objs) == 1 + assert configured_actions.objs[0].digest == digest + action_ref_uri = ObjectRef( + entity=client.entity, + project=client.project, + name=action_name, + _digest=digest, + ).uri() + + # part 2: manually create feedback + @weave.op + def example_op(input: str) -> str: + return input[::-1] + + _, call1 = example_op.call("i've been very mindful today") + with pytest.raises(Exception): + client.server.feedback_create( + FeedbackCreateReq.model_validate( + { + "project_id": client._project_id(), + "weave_ref": call1.ref.uri(), + "feedback_type": "MachineScore", + "payload": True, + } + ) + ) + + res = 
client.server.feedback_create( + FeedbackCreateReq.model_validate( + { + "project_id": client._project_id(), + "weave_ref": call1.ref.uri(), + "feedback_type": "MachineScore", + "payload": { + "runnable_ref": action_ref_uri, + "value": {action_name: {digest: True}}, + }, + } + ) + ) + + feedbacks = list(call1.feedback) + assert len(feedbacks) == 1 + assert feedbacks[0].payload == { + "runnable_ref": action_ref_uri, + "value": {action_name: {digest: True}}, + "call_ref": None, + "trigger_ref": None, + } + + # Step 3: test that we can in-place execute one action at a time. + + _, call2 = example_op.call("i've been very meditative today") + + res = client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "call_ids": [call2.id], + "configured_action_ref": action_ref_uri, + } + ) + ) + + feedbacks = list(call2.feedback) + assert len(feedbacks) == 1 + assert feedbacks[0].payload == { + "runnable_ref": action_ref_uri, + "value": {action_name: {digest: False}}, + "call_ref": None, + "trigger_ref": None, + } diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py new file mode 100644 index 00000000000..5e3b1f29f32 --- /dev/null +++ b/weave/actions_worker/tasks.py @@ -0,0 +1,252 @@ +import json +from functools import partial, wraps +from typing import Any, Callable, Optional, TypeVar + +from weave.actions_worker.celery_app import app +from weave.trace_server.action_executor import TaskCtx +from weave.trace_server.clickhouse_trace_server_batched import ( + ActionsAckBatchReq, + ClickHouseTraceServer, +) +from weave.trace_server.interface.base_models.action_base_models import ( + ConfiguredAction, + ConfiguredContainsWordsAction, + ConfiguredLlmJudgeAction, + ConfiguredNoopAction, + ConfiguredWordCountAction, +) +from weave.trace_server.interface.base_models.base_model_registry import base_model_name +from weave.trace_server.interface.base_models.feedback_base_model_registry import ( + 
MachineScore, +) +from weave.trace_server.refs_internal import ( + InternalCallRef, + InternalObjectRef, + InternalOpRef, + parse_internal_uri, +) +from weave.trace_server.trace_server_interface import ( + CallSchema, + CallsFilter, + CallsQueryReq, + FeedbackCreateReq, + RefsReadBatchReq, +) + +WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID = "WEAVE_ACTION_EXECUTOR" + + +def ack_on_clickhouse(ctx: TaskCtx, succeeded: bool) -> None: + project_id = ctx["project_id"] + call_id = ctx["call_id"] + id = ctx["id"] + ClickHouseTraceServer.from_env().actions_ack_batch( + ActionsAckBatchReq( + project_id=project_id, call_ids=[call_id], id=id, succeeded=succeeded + ) + ) + + +def publish_results_as_feedback( + ctx: TaskCtx, + result: Any, + configured_action_ref: str, + trigger_ref: Optional[str] = None, +) -> None: + project_id = ctx["project_id"] + call_id = ctx["call_id"] + id = ctx["id"] + call_ref = InternalCallRef(project_id, call_id).uri() + parsed_action_ref = parse_internal_uri(configured_action_ref) + if not isinstance(parsed_action_ref, (InternalObjectRef, InternalOpRef)): + raise ValueError(f"Invalid action ref: {configured_action_ref}") + action_name = parsed_action_ref.name + digest = parsed_action_ref.version + ClickHouseTraceServer.from_env().feedback_create( + FeedbackCreateReq( + project_id=project_id, + weave_ref=call_ref, + feedback_type=base_model_name(MachineScore), + payload=MachineScore( + runnable_ref=configured_action_ref, + value={action_name: {digest: result}}, + trigger_ref=trigger_ref, + ).model_dump(), + wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID, + ) + ) + + +def resolve_action_ref(configured_action_ref: str) -> ConfiguredAction: + server = ClickHouseTraceServer.from_env() + action_dict_res = server.refs_read_batch( + RefsReadBatchReq(refs=[configured_action_ref]) + ) + action_dict = action_dict_res.vals[0] + assert isinstance(action_dict, dict) + action = ConfiguredAction.model_validate(action_dict) + return action + + +def resolve_call(ctx: 
TaskCtx) -> CallSchema: + project_id, call_id = ctx["project_id"], ctx["call_id"] + server = ClickHouseTraceServer.from_env() + calls_query_res = server.calls_query( + CallsQueryReq( + project_id=project_id, filter=CallsFilter(call_ids=[call_id]), limit=1 + ) + ) + return calls_query_res.calls[0] + + +ActionConfigT = TypeVar("ActionConfigT") +ActionResultT = TypeVar("ActionResultT") + + +def action_task( + func: Callable[[str, str, ActionConfigT], ActionResultT], +) -> Callable[[TaskCtx, str, str, str, ActionConfigT, Optional[str]], ActionResultT]: + @wraps(func) + def wrapper( + ctx: TaskCtx, + call_input: str, + call_output: str, + configured_action_ref: str, + configured_action: ActionConfigT, + trigger_ref: Optional[str] = None, + ) -> ActionResultT: + success = True + try: + result = func(call_input, call_output, configured_action) + publish_results_as_feedback(ctx, result, configured_action_ref, trigger_ref) + # logging.info(f"Successfully ran {func.__name__}") + # logging.info(f"Result: {result}") + except Exception as e: + success = False + raise e + finally: + ack_on_clickhouse(ctx, success) + return result + + return wrapper + + +@app.task() +def do_task( + ctx: TaskCtx, configured_action_ref: str, trigger_ref: Optional[str] = None +) -> None: + action = resolve_action_ref(configured_action_ref) + call = resolve_call(ctx) + call_input = json.dumps(call.inputs) + call_output = call.output + if not isinstance(call_output, str): + call_output = json.dumps(call_output) + + if action.config.action_type == "wordcount": + wordcount( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) + elif action.config.action_type == "llm_judge": + llm_judge( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) + elif action.config.action_type == "noop": + noop( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) + elif action.config.action_type 
== "contains_words": + contains_words( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) + else: + raise ValueError(f"Unknown action type: {action.config.action_type}") + + +@action_task +def wordcount( + call_input: str, call_output: str, config: ConfiguredWordCountAction +) -> int: + return len(call_output.split(" ")) + + +@action_task +def llm_judge( + call_input: str, call_output: str, config: ConfiguredLlmJudgeAction +) -> str: + model = config.model + system_prompt = config.prompt + if config.response_format is None: + raise ValueError("response_format is required for llm_judge") + + response_is_not_object = config.response_format["type"] != "object" + dummy_key = "response" + if response_is_not_object: + schema = { + "type": "object", + "properties": {dummy_key: config.response_format}, + "additionalProperties": False, + } + else: + schema = config.response_format + + response_format = { + "type": "json_schema", + "json_schema": {"name": "response_format", "schema": schema}, + } + + args = { + "inputs": call_input, + "output": call_output, + } + from openai import OpenAI + + client = OpenAI() + # Silly hack to get around issue in tests: + create = client.chat.completions.create + if hasattr(create, "resolve_fn"): + create = partial(create.resolve_fn, self=client.chat.completions) + completion = create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(args)}, + ], + response_format=response_format, + ) + res = json.loads(completion.choices[0].message.content) + if response_is_not_object: + res = res[dummy_key] + return res + + +@action_task +def contains_words( + call_input: str, call_output: str, config: ConfiguredContainsWordsAction +) -> bool: + word_set = set(call_output.split(" ")) + return len(set(config.target_words) & word_set) > 0 + + +@action_task +def noop(call_input: str, call_output: str, config: ConfiguredNoopAction) -> None: + pass 
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index b1988616c4e..a939c2462f7 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -48,11 +48,18 @@ from clickhouse_connect.driver.client import Client as CHClient from clickhouse_connect.driver.query import QueryResult from clickhouse_connect.driver.summary import QuerySummary +from pydantic import BaseModel from weave.trace_server import clickhouse_trace_server_migrator as wf_migrator from weave.trace_server import environment as wf_env from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.action_executor import ( + ActionExecutor, + TaskCtx, + queue_from_addr, +) +from weave.trace_server.actions import TABLE_ACTIONS, get_stale_actions from weave.trace_server.base_object_class_util import ( process_incoming_object, ) @@ -82,6 +89,15 @@ validate_feedback_purge_req, ) from weave.trace_server.ids import generate_id +from weave.trace_server.interface.base_models.action_base_models import ( + ActionDispatchFilter, +) +from weave.trace_server.interface.base_models.base_model_registry import ( + base_model_name, +) +from weave.trace_server.interface.base_models.feedback_base_model_registry import ( + feedback_base_models, +) from weave.trace_server.llm_completion import lite_llm_completion from weave.trace_server.model_providers.model_providers import ( read_model_to_provider_info_map, @@ -176,6 +192,55 @@ class NotFoundError(Exception): } +@dataclasses.dataclass +class CallBatch: + """Represents a batch of calls to be inserted into Clickhouse.""" + + calls: list[list[Any]] + project_id: Optional[str] + call_ids: list[str] # Track IDs of calls in the batch + + def __init__(self) -> None: + self.calls = [] + self.project_id = None + self.call_ids = [] + + def add_call(self, ch_call: 
CallCHInsertable) -> None: + """Add a call to the batch, ensuring project_id consistency.""" + if self.project_id is None: + self.project_id = ch_call.project_id + elif self.project_id != ch_call.project_id: + raise ValueError("All calls in a batch must have the same project_id") + + parameters = ch_call.model_dump() + row = [] + for key in all_call_insert_columns: + row.append(parameters.get(key, None)) + self.calls.append(row) + self.call_ids.append(ch_call.id) + + def clear(self) -> None: + """Reset the batch to empty state.""" + self.calls = [] + self.project_id = None + self.call_ids = [] + + def __bool__(self) -> bool: + """Return True if the batch has any calls.""" + return bool(self.calls) + + +class ActionsAckBatchReq(BaseModel): + project_id: str + call_ids: list[str] + id: str + succeeded: bool + + +class ActionsAckBatchRes(BaseModel): + id: str + + class ClickHouseTraceServer(tsi.TraceServerInterface): def __init__( self, @@ -185,6 +250,7 @@ def __init__( user: str = "default", password: str = "", database: str = "default", + action_executor_addr: str, use_async_insert: bool = False, ): super().__init__() @@ -195,8 +261,9 @@ def __init__( self._password = password self._database = database self._flush_immediately = True - self._call_batch: list[list[Any]] = [] + self._call_batch = CallBatch() self._use_async_insert = use_async_insert + self._action_executor_addr = action_executor_addr self._model_to_provider_info_map = read_model_to_provider_info_map() @classmethod @@ -209,6 +276,7 @@ def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer": user=wf_env.wf_clickhouse_user(), password=wf_env.wf_clickhouse_pass(), database=wf_env.wf_clickhouse_database(), + action_executor_addr=wf_env.wf_action_executor(), use_async_insert=use_async_insert, ) @@ -220,7 +288,7 @@ def call_batch(self) -> Iterator[None]: yield self._flush_calls() finally: - self._call_batch = [] + self._call_batch.clear() self._flush_immediately = True # Creates a new 
call @@ -1335,8 +1403,17 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: assert_non_null_wb_user_id(req) validate_feedback_create_req(req) + feedback_type = req.feedback_type + req_payload = req.payload + + for feedback_base_model in feedback_base_models: + if base_model_name(feedback_base_model) == feedback_type: + req_payload = feedback_base_model.model_validate( + req_payload + ).model_dump() + break + # Augment emoji with alias. - res_payload = {} if req.feedback_type == "wandb.reaction.1": em = req.payload["emoji"] if emoji.emoji_count(em) != 1: @@ -1347,12 +1424,12 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: detoned = detone_emojis(em) req.payload["detoned"] = detoned req.payload["detoned_alias"] = emoji.demojize(detoned) - res_payload = req.payload + req_payload = req.payload feedback_id = generate_id() created_at = datetime.datetime.now(ZoneInfo("UTC")) # TODO: Any validation on weave_ref? - payload = _dict_value_to_dump(req.payload) + payload = _dict_value_to_dump(req_payload) MAX_PAYLOAD = 1024 if len(payload) > MAX_PAYLOAD: raise InvalidRequest("Feedback payload too large") @@ -1363,12 +1440,8 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "wb_user_id": req.wb_user_id, "creator": req.creator, "feedback_type": req.feedback_type, - "payload": req.payload, + "payload": req_payload, "created_at": created_at, - "annotation_ref": req.annotation_ref, - "runnable_ref": req.runnable_ref, - "call_ref": req.call_ref, - "trigger_ref": req.trigger_ref, } prepared = TABLE_FEEDBACK.insert(row).prepare(database_type="clickhouse") self._insert(TABLE_FEEDBACK.name, prepared.data, prepared.column_names) @@ -1376,7 +1449,7 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: id=feedback_id, created_at=created_at, wb_user_id=req.wb_user_id, - payload=res_payload, + payload=req_payload, ) def feedback_query(self, req: 
tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: @@ -1406,6 +1479,89 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: self.ch_client.query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() + def actions_execute_batch( + self, req: tsi.ActionsExecuteBatchReq + ) -> tsi.ActionsExecuteBatchRes: + # NOTE: Clients should try to generate their own ids, and retry using the same IDs in case of failures. + # That way we can avoid unnecessary duplicates if the server fails after inserting the batch into CH + # but before inserting into the action queue. + + # Step 1: Prepare data to insert into CH actions table and queue. + id = req.id or generate_id() + created_at = datetime.datetime.now(ZoneInfo("UTC")) + rows: list[Row] = [ + { + "project_id": req.project_id, + "call_id": call_id, + "id": id, + "configured_action": req.configured_action_ref, + "created_at": created_at, + } + for call_id in req.call_ids + ] + task_ctxs: list[TaskCtx] = [ + {"project_id": req.project_id, "call_id": call_id, "id": id} + for call_id in req.call_ids + ] + # Step 2: Potential shortcut: if there is only one call, we can do the action right away. + if len(req.call_ids) == 1: + self.action_executor.do_now(task_ctxs[0], req.configured_action_ref) + return tsi.ActionsExecuteBatchRes(id=id) + + # Step 3: Normal case: enqueue the actions in CH and the worker queue. 
+ prepared = TABLE_ACTIONS.insert_many(rows).prepare(database_type="clickhouse") + self._insert(TABLE_ACTIONS.name, prepared.data, prepared.column_names) + + for task_ctx in task_ctxs: + self.action_executor.enqueue(task_ctx, req.configured_action_ref) + return tsi.ActionsExecuteBatchRes(id=id) + + def actions_ack_batch(self, req: ActionsAckBatchReq) -> ActionsAckBatchRes: + received_at = datetime.datetime.now(ZoneInfo("UTC")) + rows: list[Row] = [ + { + "project_id": req.project_id, + "call_id": call_id, + "id": req.id, + "finished_at": received_at if req.succeeded else None, + "failed_at": received_at if not req.succeeded else None, + } + for call_id in req.call_ids + ] + prepared = TABLE_ACTIONS.insert_many(rows).prepare(database_type="clickhouse") + self._insert(TABLE_ACTIONS.name, prepared.data, prepared.column_names) + + return ActionsAckBatchRes(id=req.id) + + # NOTE: This is a private admin function, meant to be invoked by an action cleaner. + # The action cleaner's job is to find actions that have been in the action queue for too long, and requeue them. + def _actions_requeue_stale(self) -> None: + try: + prepared_select = get_stale_actions( + older_than=datetime.datetime.now(ZoneInfo("UTC")) + - datetime.timedelta(hours=1) + ) + query_result = self._query(prepared_select.sql, prepared_select.parameters) + rows = TABLE_ACTIONS.tuples_to_rows( + query_result.result_rows, prepared_select.fields + ) + for row in rows: + try: + self.action_executor.enqueue( + { + "project_id": row["project_id"], # type: ignore + "call_id": row["call_id"], # type: ignore + "id": row["id"], # type: ignore + }, + row["configured_action"], # type: ignore + row.get("trigger_ref"), # type: ignore + ) + except Exception as e: + logger.error(f"Failed to requeue action: {row}. 
Error: {str(e)}") + except Exception as e: + logger.error(f"Error in _actions_requeue_stale: {str(e)}") + raise + def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: @@ -1471,6 +1627,14 @@ def ch_client(self) -> CHClient: self._thread_local.ch_client = self._mint_client() return self._thread_local.ch_client + @property + def action_executor(self) -> ActionExecutor: + if not hasattr(self._thread_local, "action_executor"): + self._thread_local.action_executor = queue_from_addr( + self._action_executor_addr + ) + return self._thread_local.action_executor + def _mint_client(self) -> CHClient: client = clickhouse_connect.get_client( host=self._host, @@ -1484,9 +1648,6 @@ def _mint_client(self) -> CHClient: client.database = self._database return client - # def __del__(self) -> None: - # self.ch_client.close() - def _insert_call_batch(self, batch: list) -> None: if batch: settings = {} @@ -1765,23 +1926,147 @@ def _insert( raise def _insert_call(self, ch_call: CallCHInsertable) -> None: - parameters = ch_call.model_dump() - row = [] - for key in all_call_insert_columns: - row.append(parameters.get(key, None)) - self._call_batch.append(row) + self._call_batch.add_call(ch_call) if self._flush_immediately: self._flush_calls() + def _get_matched_calls_for_filters( + self, project_id: str, call_ids: list[str] + ) -> list[tuple[ActionDispatchFilter, str, list[tsi.CallSchema]]]: + """Helper function to get calls that match action filters. + + Returns a list of tuples containing (filter, matched_calls) pairs. + """ + # Get all action filters for the project + filter_req = tsi.ObjQueryReq( + project_id=project_id, + filter=tsi.ObjectVersionFilter( + base_object_classes=[base_model_name(ActionDispatchFilter)], + is_op=False, + latest_only=True, # IMPORTANT: Always keep this. 
+ ), + ) + filter_res = self.objs_query(filter_req) + filters: list[tuple[ActionDispatchFilter, str]] = [ + ( + ActionDispatchFilter.model_validate(obj.val), + ri.InternalObjectRef( + project_id=project_id, + name=obj.object_id, + version=obj.digest, + ).uri(), + ) + for obj in filter_res.objs + ] + + if not filters: + return [] + + # Get all finished calls from the batch + calls_query_filter = tsi.Query.model_validate( + { + "$expr": { + "$and": [ + { + "$in": [ + {"$getField": "id"}, + [{"$literal": id} for id in call_ids], + ] + }, + { + "$not": ( + { + "$eq": [ + {"$getField": "started_at"}, + {"$literal": None}, + ] + }, + ) + }, + { + "$not": ( + { + "$eq": [ + {"$getField": "ended_at"}, + {"$literal": None}, + ] + }, + ) + }, + ] + } + } + ) + calls_query_req = tsi.CallsQueryReq( + project_id=project_id, + query=calls_query_filter, + ) + calls_res = self.calls_query(calls_query_req) + + # TODO: this can be lifted to the top once core's deps are updated + import xxhash + + # Match calls to filters + matched_filters_and_calls = [] + for filter, filter_ref in filters: + if filter.disabled: + continue + calls_with_refs = [ + (call, ri.parse_internal_uri(call.op_name)) for call in calls_res.calls + ] + matched_calls = [ + c[0] + for c in calls_with_refs + if isinstance(c[1], ri.InternalObjectRef) + and c[1].name == filter.op_name + ] + + if filter.sample_rate < 1.0: + matched_calls = [ + call + for call in matched_calls + if abs(xxhash.xxh32(call.id).intdigest()) % 10000 + < filter.sample_rate * 10000 + ] + + if matched_calls: + matched_filters_and_calls.append((filter, filter_ref, matched_calls)) + + return matched_filters_and_calls + def _flush_calls(self) -> None: + if not self._call_batch: + return + calls = self._call_batch.calls + if not calls: + return + project_id = self._call_batch.project_id + + assert project_id is not None + try: - self._insert_call_batch(self._call_batch) + self._insert_call_batch(calls) except InsertTooLarge: logger.info("Retrying 
with large objects stripped.") - batch = self._strip_large_values(self._call_batch) + batch = self._strip_large_values(calls) self._insert_call_batch(batch) - self._call_batch = [] + # Find and process matched calls for each filter + matched_filters_and_calls = self._get_matched_calls_for_filters( + project_id, self._call_batch.call_ids + ) + + for filter, filter_ref, matched_calls in matched_filters_and_calls: + self.actions_execute_batch( + tsi.ActionsExecuteBatchReq( + project_id=project_id, + call_ids=[call.id for call in matched_calls], + configured_action_ref=filter.configured_action_ref, + trigger_ref=filter_ref, + ) + ) + + self._call_batch.clear() def _strip_large_values(self, batch: list[list[Any]]) -> list[list[Any]]: """ @@ -2020,21 +2305,6 @@ def _process_parameters( return parameters -# def _partial_obj_schema_to_ch_obj( -# partial_obj: tsi.ObjSchemaForInsert, -# ) -> ObjCHInsertable: -# version_hash = version_hash_for_object(partial_obj) - -# return ObjCHInsertable( -# id=uuid.uuid4(), -# project_id=partial_obj.project_id, -# name=partial_obj.name, -# type="unknown", -# refs=[], -# val=json.dumps(partial_obj.val), -# ) - - def get_type(val: Any) -> str: if val == None: return "none" diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 7e085b8f75e..dd81f002e6d 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -346,6 +346,13 @@ def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: cost["pricing_level_id"] = original_project_id return res + def actions_execute_batch( + self, req: tsi.ActionsExecuteBatchReq + ) -> tsi.ActionsExecuteBatchRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + res = self._ref_apply(self._internal_trace_server.actions_execute_batch, req) + return res + def completions_create( self, req: 
tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: diff --git a/weave/trace_server/interface/base_models/action_base_models.py b/weave/trace_server/interface/base_models/action_base_models.py new file mode 100644 index 00000000000..476293ceca0 --- /dev/null +++ b/weave/trace_server/interface/base_models/action_base_models.py @@ -0,0 +1,44 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + + +class ConfiguredLlmJudgeAction(BaseModel): + action_type: Literal["llm_judge"] = "llm_judge" + model: Literal["gpt-4o", "gpt-4o-mini"] + prompt: str + response_format: Optional[dict[str, Any]] + + +class ConfiguredContainsWordsAction(BaseModel): + action_type: Literal["contains_words"] = "contains_words" + target_words: list[str] + + +class ConfiguredWordCountAction(BaseModel): + action_type: Literal["wordcount"] = "wordcount" + + +class ConfiguredNoopAction(BaseModel): + action_type: Literal["noop"] = "noop" + + +ActionConfigType = Union[ + ConfiguredLlmJudgeAction, + ConfiguredContainsWordsAction, + ConfiguredWordCountAction, + ConfiguredNoopAction, +] + + +class ConfiguredAction(BaseModel): + name: str + config: ActionConfigType + + +# CRITICAL! 
When saving this object, always set the object_id == "op_id-action_id" +class ActionDispatchFilter(BaseModel): + op_name: str + sample_rate: float + configured_action_ref: str + disabled: Optional[bool] = False diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 94c0d8c6ecc..85d4117695f 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1007,10 +1007,6 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "feedback_type": req.feedback_type, "payload": req.payload, "created_at": created_at, - "annotation_ref": req.annotation_ref, - "runnable_ref": req.runnable_ref, - "call_ref": req.call_ref, - "trigger_ref": req.trigger_ref, } conn, cursor = get_conn_cursor(self.db_path) with self.lock: @@ -1053,6 +1049,13 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: conn.commit() return tsi.FeedbackPurgeRes() + def actions_execute_batch( + self, req: tsi.ActionsExecuteBatchReq + ) -> tsi.ActionsExecuteBatchRes: + raise NotImplementedError( + "actions_execute_batch is not implemented for SQLite trace server" + ) + def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: conn, cursor = get_conn_cursor(self.db_path) digest = bytes_digest(req.content) diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 88b849d644c..55c6db51b79 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -686,18 +686,6 @@ class FeedbackCreateReq(BaseModel): } ] ) - annotation_ref: Optional[str] = Field( - default=None, examples=["weave:///entity/project/object/name:digest"] - ) - runnable_ref: Optional[str] = Field( - default=None, examples=["weave:///entity/project/op/name:digest"] - ) - call_ref: Optional[str] = Field( - default=None, examples=["weave:///entity/project/call/call_id"] - ) - trigger_ref: 
Optional[str] = Field( - default=None, examples=["weave:///entity/project/object/name:digest"] - ) # wb_user_id is automatically populated by the server wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) @@ -849,6 +837,20 @@ class CostPurgeRes(BaseModel): pass +class ActionsExecuteBatchReq(BaseModel): + project_id: str + call_ids: list[str] + configured_action_ref: str + trigger_ref: Optional[str] = None + # `id` is here so that clients can potentially guarantee idempotence. + # Repeated calls with the same id will not result in duplicate actions. + id: Optional[str] = None + + +class ActionsExecuteBatchRes(BaseModel): + id: str + + class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str @@ -890,5 +892,11 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... + + # Action API + def actions_execute_batch( + self, req: ActionsExecuteBatchReq + ) -> ActionsExecuteBatchRes: ... + # Execute LLM API def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ... 
diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 34b906a560c..6f1e83087cf 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Iterator, List, Optional, Tuple, Type, Union, cast +from typing import Any, Iterator, List, Optional, Tuple, Type, TypeVar, Union, cast import tenacity from pydantic import BaseModel, ValidationError @@ -214,13 +214,16 @@ def _generic_request_executor( return r + ReqType = TypeVar("ReqType", bound=BaseModel) + ResType = TypeVar("ResType", bound=BaseModel) + def _generic_request( self, url: str, - req: BaseModel, - req_model: Type[BaseModel], - res_model: Type[BaseModel], - ) -> BaseModel: + req: Union[ReqType, dict[str, Any]], + req_model: Type[ReqType], + res_model: Type[ResType], + ) -> ResType: if isinstance(req, dict): req = req_model.model_validate(req) r = self._generic_request_executor(url, req) @@ -229,10 +232,10 @@ def _generic_request( def _generic_stream_request( self, url: str, - req: BaseModel, - req_model: Type[BaseModel], - res_model: Type[BaseModel], - ) -> Iterator[BaseModel]: + req: ReqType, + req_model: Type[ReqType], + res_model: Type[ResType], + ) -> Iterator[ResType]: if isinstance(req, dict): req = req_model.model_validate(req) r = self._generic_request_executor(url, req, stream=True) @@ -527,6 +530,16 @@ def feedback_purge( "/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes ) + def actions_execute_batch( + self, req: Union[tsi.ActionsExecuteBatchReq, dict[str, Any]] + ) -> tsi.ActionsExecuteBatchRes: + return self._generic_request( + "/actions/execute_batch", + req, + tsi.ActionsExecuteBatchReq, + tsi.ActionsExecuteBatchRes, + ) + # Cost API def cost_query( self, req: Union[tsi.CostQueryReq, dict[str, Any]] From f3b9e66ca7665621eb812ee49965fa700e2b37db Mon Sep 17 
00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 13:27:13 -0800 Subject: [PATCH 017/120] Initial Cleanup --- weave/trace_server/sqlite_trace_server.py | 4 ++++ weave/trace_server/trace_server_interface.py | 19 +++++++++++------ .../remote_http_trace_server.py | 21 ++++++++----------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 85d4117695f..76691732e10 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1007,6 +1007,10 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "feedback_type": req.feedback_type, "payload": req.payload, "created_at": created_at, + "annotation_ref": req.annotation_ref, + "runnable_ref": req.runnable_ref, + "call_ref": req.call_ref, + "trigger_ref": req.trigger_ref, } conn, cursor = get_conn_cursor(self.db_path) with self.lock: diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 55c6db51b79..ea9027abbda 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -686,7 +686,18 @@ class FeedbackCreateReq(BaseModel): } ] ) - + annotation_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/object/name:digest"] + ) + runnable_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/op/name:digest"] + ) + call_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/call/call_id"] + ) + trigger_ref: Optional[str] = Field( + default=None, examples=["weave:///entity/project/object/name:digest"] + ) # wb_user_id is automatically populated by the server wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) @@ -841,14 +852,10 @@ class ActionsExecuteBatchReq(BaseModel): project_id: str call_ids: list[str] configured_action_ref: str - trigger_ref: 
Optional[str] = None - # `id` is here so that clients can potentially guarantee idempotence. - # Repeated calls with the same id will not result in duplicate actions. - id: Optional[str] = None class ActionsExecuteBatchRes(BaseModel): - id: str + pass class TraceServerInterface(Protocol): diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 6f1e83087cf..0f745644848 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Iterator, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import Any, Iterator, List, Optional, Tuple, Type, Union, cast import tenacity from pydantic import BaseModel, ValidationError @@ -214,16 +214,13 @@ def _generic_request_executor( return r - ReqType = TypeVar("ReqType", bound=BaseModel) - ResType = TypeVar("ResType", bound=BaseModel) - def _generic_request( self, url: str, - req: Union[ReqType, dict[str, Any]], - req_model: Type[ReqType], - res_model: Type[ResType], - ) -> ResType: + req: BaseModel, + req_model: Type[BaseModel], + res_model: Type[BaseModel], + ) -> BaseModel: if isinstance(req, dict): req = req_model.model_validate(req) r = self._generic_request_executor(url, req) @@ -232,10 +229,10 @@ def _generic_request( def _generic_stream_request( self, url: str, - req: ReqType, - req_model: Type[ReqType], - res_model: Type[ResType], - ) -> Iterator[ResType]: + req: BaseModel, + req_model: Type[BaseModel], + res_model: Type[BaseModel], + ) -> Iterator[BaseModel]: if isinstance(req, dict): req = req_model.model_validate(req) r = self._generic_request_executor(url, req, stream=True) From 54636987a9200aa514e33a15556cfda6ca53aa03 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 14:01:37 -0800 Subject: [PATCH 018/120] Initial Action changes --- weave/actions_worker/tasks.py | 218 
+++--------------- .../base_models/action_base_models.py | 3 + 2 files changed, 37 insertions(+), 184 deletions(-) diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index 5e3b1f29f32..9b70171c73b 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -1,23 +1,13 @@ import json -from functools import partial, wraps -from typing import Any, Callable, Optional, TypeVar +from functools import partial +from typing import Any, Tuple -from weave.actions_worker.celery_app import app -from weave.trace_server.action_executor import TaskCtx -from weave.trace_server.clickhouse_trace_server_batched import ( - ActionsAckBatchReq, - ClickHouseTraceServer, -) +from openai import OpenAI + +from weave.trace_server.feedback import RunnablePayloadSchema from weave.trace_server.interface.base_models.action_base_models import ( - ConfiguredAction, - ConfiguredContainsWordsAction, + ActionConfigType, ConfiguredLlmJudgeAction, - ConfiguredNoopAction, - ConfiguredWordCountAction, -) -from weave.trace_server.interface.base_models.base_model_registry import base_model_name -from weave.trace_server.interface.base_models.feedback_base_model_registry import ( - MachineScore, ) from weave.trace_server.refs_internal import ( InternalCallRef, @@ -27,172 +17,33 @@ ) from weave.trace_server.trace_server_interface import ( CallSchema, - CallsFilter, - CallsQueryReq, FeedbackCreateReq, - RefsReadBatchReq, ) -WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID = "WEAVE_ACTION_EXECUTOR" - - -def ack_on_clickhouse(ctx: TaskCtx, succeeded: bool) -> None: - project_id = ctx["project_id"] - call_id = ctx["call_id"] - id = ctx["id"] - ClickHouseTraceServer.from_env().actions_ack_batch( - ActionsAckBatchReq( - project_id=project_id, call_ids=[call_id], id=id, succeeded=succeeded - ) - ) - def publish_results_as_feedback( - ctx: TaskCtx, + target_call: CallSchema, + runnable_ref: str, result: Any, - configured_action_ref: str, - trigger_ref: Optional[str] = None, -) -> None: 
- project_id = ctx["project_id"] - call_id = ctx["call_id"] - id = ctx["id"] - call_ref = InternalCallRef(project_id, call_id).uri() - parsed_action_ref = parse_internal_uri(configured_action_ref) +) -> FeedbackCreateReq: + project_id = target_call.project_id + call_id = target_call.id + weave_ref = InternalCallRef(project_id, call_id).uri() + parsed_action_ref = parse_internal_uri(runnable_ref) if not isinstance(parsed_action_ref, (InternalObjectRef, InternalOpRef)): - raise ValueError(f"Invalid action ref: {configured_action_ref}") + raise ValueError(f"Invalid action ref: {runnable_ref}") action_name = parsed_action_ref.name - digest = parsed_action_ref.version - ClickHouseTraceServer.from_env().feedback_create( - FeedbackCreateReq( - project_id=project_id, - weave_ref=call_ref, - feedback_type=base_model_name(MachineScore), - payload=MachineScore( - runnable_ref=configured_action_ref, - value={action_name: {digest: result}}, - trigger_ref=trigger_ref, - ).model_dump(), - wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID, - ) - ) - - -def resolve_action_ref(configured_action_ref: str) -> ConfiguredAction: - server = ClickHouseTraceServer.from_env() - action_dict_res = server.refs_read_batch( - RefsReadBatchReq(refs=[configured_action_ref]) - ) - action_dict = action_dict_res.vals[0] - assert isinstance(action_dict, dict) - action = ConfiguredAction.model_validate(action_dict) - return action - -def resolve_call(ctx: TaskCtx) -> CallSchema: - project_id, call_id = ctx["project_id"], ctx["call_id"] - server = ClickHouseTraceServer.from_env() - calls_query_res = server.calls_query( - CallsQueryReq( - project_id=project_id, filter=CallsFilter(call_ids=[call_id]), limit=1 - ) + return FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.runnable." 
+ action_name, + runnable_ref=runnable_ref, + payload=RunnablePayloadSchema(output=result).model_dump(), ) - return calls_query_res.calls[0] - -ActionConfigT = TypeVar("ActionConfigT") -ActionResultT = TypeVar("ActionResultT") - -def action_task( - func: Callable[[str, str, ActionConfigT], ActionResultT], -) -> Callable[[TaskCtx, str, str, str, ActionConfigT, Optional[str]], ActionResultT]: - @wraps(func) - def wrapper( - ctx: TaskCtx, - call_input: str, - call_output: str, - configured_action_ref: str, - configured_action: ActionConfigT, - trigger_ref: Optional[str] = None, - ) -> ActionResultT: - success = True - try: - result = func(call_input, call_output, configured_action) - publish_results_as_feedback(ctx, result, configured_action_ref, trigger_ref) - # logging.info(f"Successfully ran {func.__name__}") - # logging.info(f"Result: {result}") - except Exception as e: - success = False - raise e - finally: - ack_on_clickhouse(ctx, success) - return result - - return wrapper - - -@app.task() -def do_task( - ctx: TaskCtx, configured_action_ref: str, trigger_ref: Optional[str] = None -) -> None: - action = resolve_action_ref(configured_action_ref) - call = resolve_call(ctx) - call_input = json.dumps(call.inputs) - call_output = call.output - if not isinstance(call_output, str): - call_output = json.dumps(call_output) - - if action.config.action_type == "wordcount": - wordcount( - ctx, - call_input, - call_output, - configured_action_ref, - action.config, - trigger_ref, - ) - elif action.config.action_type == "llm_judge": - llm_judge( - ctx, - call_input, - call_output, - configured_action_ref, - action.config, - trigger_ref, - ) - elif action.config.action_type == "noop": - noop( - ctx, - call_input, - call_output, - configured_action_ref, - action.config, - trigger_ref, - ) - elif action.config.action_type == "contains_words": - contains_words( - ctx, - call_input, - call_output, - configured_action_ref, - action.config, - trigger_ref, - ) - else: - raise 
ValueError(f"Unknown action type: {action.config.action_type}") - - -@action_task -def wordcount( - call_input: str, call_output: str, config: ConfiguredWordCountAction -) -> int: - return len(call_output.split(" ")) - - -@action_task -def llm_judge( - call_input: str, call_output: str, config: ConfiguredLlmJudgeAction -) -> str: +def do_llm_judge_action(config: ConfiguredLlmJudgeAction, call: CallSchema) -> Any: model = config.model system_prompt = config.prompt if config.response_format is None: @@ -215,10 +66,9 @@ def llm_judge( } args = { - "inputs": call_input, - "output": call_output, + "inputs": call.inputs, + "output": call.output, } - from openai import OpenAI client = OpenAI() # Silly hack to get around issue in tests: @@ -239,14 +89,14 @@ def llm_judge( return res -@action_task -def contains_words( - call_input: str, call_output: str, config: ConfiguredContainsWordsAction -) -> bool: - word_set = set(call_output.split(" ")) - return len(set(config.target_words) & word_set) > 0 - - -@action_task -def noop(call_input: str, call_output: str, config: ConfiguredNoopAction) -> None: - pass +def do_action( + configured_action_ref: str, action_config: ActionConfigType, call: CallSchema +) -> Tuple[Any, FeedbackCreateReq]: + runnable_ref = None + if isinstance(action_config, ConfiguredLlmJudgeAction): + result = do_llm_judge_action(action_config, call) + runnable_ref = configured_action_ref + else: + raise ValueError(f"Unsupported action config: {action_config}") + req = publish_results_as_feedback(call, runnable_ref, result) + return result, req diff --git a/weave/trace_server/interface/base_models/action_base_models.py b/weave/trace_server/interface/base_models/action_base_models.py index 476293ceca0..0b9dbb5c475 100644 --- a/weave/trace_server/interface/base_models/action_base_models.py +++ b/weave/trace_server/interface/base_models/action_base_models.py @@ -2,6 +2,9 @@ from pydantic import BaseModel +# class BaseActionModel(BaseModel): +# action_type: str + 
class ConfiguredLlmJudgeAction(BaseModel): action_type: Literal["llm_judge"] = "llm_judge" From 001b9d56259d131b43d58c2a53bb7bd79ebd1a59 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 14:33:33 -0800 Subject: [PATCH 019/120] Initial Refactor --- .../clickhouse_trace_server_batched.py | 337 ++---------------- 1 file changed, 34 insertions(+), 303 deletions(-) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index a939c2462f7..0ae6eedd9a3 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -48,18 +48,11 @@ from clickhouse_connect.driver.client import Client as CHClient from clickhouse_connect.driver.query import QueryResult from clickhouse_connect.driver.summary import QuerySummary -from pydantic import BaseModel from weave.trace_server import clickhouse_trace_server_migrator as wf_migrator from weave.trace_server import environment as wf_env from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.action_executor import ( - ActionExecutor, - TaskCtx, - queue_from_addr, -) -from weave.trace_server.actions import TABLE_ACTIONS, get_stale_actions from weave.trace_server.base_object_class_util import ( process_incoming_object, ) @@ -89,15 +82,6 @@ validate_feedback_purge_req, ) from weave.trace_server.ids import generate_id -from weave.trace_server.interface.base_models.action_base_models import ( - ActionDispatchFilter, -) -from weave.trace_server.interface.base_models.base_model_registry import ( - base_model_name, -) -from weave.trace_server.interface.base_models.feedback_base_model_registry import ( - feedback_base_models, -) from weave.trace_server.llm_completion import lite_llm_completion from weave.trace_server.model_providers.model_providers import ( read_model_to_provider_info_map, @@ -192,55 +176,6 @@ class 
NotFoundError(Exception): } -@dataclasses.dataclass -class CallBatch: - """Represents a batch of calls to be inserted into Clickhouse.""" - - calls: list[list[Any]] - project_id: Optional[str] - call_ids: list[str] # Track IDs of calls in the batch - - def __init__(self) -> None: - self.calls = [] - self.project_id = None - self.call_ids = [] - - def add_call(self, ch_call: CallCHInsertable) -> None: - """Add a call to the batch, ensuring project_id consistency.""" - if self.project_id is None: - self.project_id = ch_call.project_id - elif self.project_id != ch_call.project_id: - raise ValueError("All calls in a batch must have the same project_id") - - parameters = ch_call.model_dump() - row = [] - for key in all_call_insert_columns: - row.append(parameters.get(key, None)) - self.calls.append(row) - self.call_ids.append(ch_call.id) - - def clear(self) -> None: - """Reset the batch to empty state.""" - self.calls = [] - self.project_id = None - self.call_ids = [] - - def __bool__(self) -> bool: - """Return True if the batch has any calls.""" - return bool(self.calls) - - -class ActionsAckBatchReq(BaseModel): - project_id: str - call_ids: list[str] - id: str - succeeded: bool - - -class ActionsAckBatchRes(BaseModel): - id: str - - class ClickHouseTraceServer(tsi.TraceServerInterface): def __init__( self, @@ -250,7 +185,6 @@ def __init__( user: str = "default", password: str = "", database: str = "default", - action_executor_addr: str, use_async_insert: bool = False, ): super().__init__() @@ -261,9 +195,8 @@ def __init__( self._password = password self._database = database self._flush_immediately = True - self._call_batch = CallBatch() + self._call_batch: list[list[Any]] = [] self._use_async_insert = use_async_insert - self._action_executor_addr = action_executor_addr self._model_to_provider_info_map = read_model_to_provider_info_map() @classmethod @@ -276,7 +209,6 @@ def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer": 
user=wf_env.wf_clickhouse_user(), password=wf_env.wf_clickhouse_pass(), database=wf_env.wf_clickhouse_database(), - action_executor_addr=wf_env.wf_action_executor(), use_async_insert=use_async_insert, ) @@ -288,7 +220,7 @@ def call_batch(self) -> Iterator[None]: yield self._flush_calls() finally: - self._call_batch.clear() + self._call_batch = [] self._flush_immediately = True # Creates a new call @@ -1403,17 +1335,8 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: assert_non_null_wb_user_id(req) validate_feedback_create_req(req) - feedback_type = req.feedback_type - req_payload = req.payload - - for feedback_base_model in feedback_base_models: - if base_model_name(feedback_base_model) == feedback_type: - req_payload = feedback_base_model.model_validate( - req_payload - ).model_dump() - break - # Augment emoji with alias. + res_payload = {} if req.feedback_type == "wandb.reaction.1": em = req.payload["emoji"] if emoji.emoji_count(em) != 1: @@ -1424,12 +1347,12 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: detoned = detone_emojis(em) req.payload["detoned"] = detoned req.payload["detoned_alias"] = emoji.demojize(detoned) - req_payload = req.payload + res_payload = req.payload feedback_id = generate_id() created_at = datetime.datetime.now(ZoneInfo("UTC")) # TODO: Any validation on weave_ref? 
- payload = _dict_value_to_dump(req_payload) + payload = _dict_value_to_dump(req.payload) MAX_PAYLOAD = 1024 if len(payload) > MAX_PAYLOAD: raise InvalidRequest("Feedback payload too large") @@ -1440,7 +1363,7 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "wb_user_id": req.wb_user_id, "creator": req.creator, "feedback_type": req.feedback_type, - "payload": req_payload, + "payload": req.payload, "created_at": created_at, } prepared = TABLE_FEEDBACK.insert(row).prepare(database_type="clickhouse") @@ -1449,7 +1372,7 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: id=feedback_id, created_at=created_at, wb_user_id=req.wb_user_id, - payload=req_payload, + payload=res_payload, ) def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: @@ -1482,85 +1405,7 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: - # NOTE: Clients should try to generate their own ids, and retry using the same IDs in case of failures. - # That way we can avoid unnecessary duplicates if the server fails after inserting the batch into CH - # but before inserting into the action queue. - - # Step 1: Prepare data to insert into CH actions table and queue. - id = req.id or generate_id() - created_at = datetime.datetime.now(ZoneInfo("UTC")) - rows: list[Row] = [ - { - "project_id": req.project_id, - "call_id": call_id, - "id": id, - "configured_action": req.configured_action_ref, - "created_at": created_at, - } - for call_id in req.call_ids - ] - task_ctxs: list[TaskCtx] = [ - {"project_id": req.project_id, "call_id": call_id, "id": id} - for call_id in req.call_ids - ] - # Step 2: Potential shortcut: if there is only one call, we can do the action right away. 
- if len(req.call_ids) == 1: - self.action_executor.do_now(task_ctxs[0], req.configured_action_ref) - return tsi.ActionsExecuteBatchRes(id=id) - - # Step 3: Normal case: enqueue the actions in CH and the worker queue. - prepared = TABLE_ACTIONS.insert_many(rows).prepare(database_type="clickhouse") - self._insert(TABLE_ACTIONS.name, prepared.data, prepared.column_names) - - for task_ctx in task_ctxs: - self.action_executor.enqueue(task_ctx, req.configured_action_ref) - return tsi.ActionsExecuteBatchRes(id=id) - - def actions_ack_batch(self, req: ActionsAckBatchReq) -> ActionsAckBatchRes: - received_at = datetime.datetime.now(ZoneInfo("UTC")) - rows: list[Row] = [ - { - "project_id": req.project_id, - "call_id": call_id, - "id": req.id, - "finished_at": received_at if req.succeeded else None, - "failed_at": received_at if not req.succeeded else None, - } - for call_id in req.call_ids - ] - prepared = TABLE_ACTIONS.insert_many(rows).prepare(database_type="clickhouse") - self._insert(TABLE_ACTIONS.name, prepared.data, prepared.column_names) - - return ActionsAckBatchRes(id=req.id) - - # NOTE: This is a private admin function, meant to be invoked by an action cleaner. - # The action cleaner's job is to find actions that have been in the action queue for too long, and requeue them. 
- def _actions_requeue_stale(self) -> None: - try: - prepared_select = get_stale_actions( - older_than=datetime.datetime.now(ZoneInfo("UTC")) - - datetime.timedelta(hours=1) - ) - query_result = self._query(prepared_select.sql, prepared_select.parameters) - rows = TABLE_ACTIONS.tuples_to_rows( - query_result.result_rows, prepared_select.fields - ) - for row in rows: - try: - self.action_executor.enqueue( - { - "project_id": row["project_id"], # type: ignore - "call_id": row["call_id"], # type: ignore - "id": row["id"], # type: ignore - }, - row["configured_action"], # type: ignore - row.get("trigger_ref"), # type: ignore - ) - except Exception as e: - logger.error(f"Failed to requeue action: {row}. Error: {str(e)}") - except Exception as e: - logger.error(f"Error in _actions_requeue_stale: {str(e)}") - raise + raise NotImplementedError() def completions_create( self, req: tsi.CompletionsCreateReq @@ -1627,14 +1472,6 @@ def ch_client(self) -> CHClient: self._thread_local.ch_client = self._mint_client() return self._thread_local.ch_client - @property - def action_executor(self) -> ActionExecutor: - if not hasattr(self._thread_local, "action_executor"): - self._thread_local.action_executor = queue_from_addr( - self._action_executor_addr - ) - return self._thread_local.action_executor - def _mint_client(self) -> CHClient: client = clickhouse_connect.get_client( host=self._host, @@ -1648,6 +1485,9 @@ def _mint_client(self) -> CHClient: client.database = self._database return client + # def __del__(self) -> None: + # self.ch_client.close() + def _insert_call_batch(self, batch: list) -> None: if batch: settings = {} @@ -1926,147 +1766,23 @@ def _insert( raise def _insert_call(self, ch_call: CallCHInsertable) -> None: - self._call_batch.add_call(ch_call) + parameters = ch_call.model_dump() + row = [] + for key in all_call_insert_columns: + row.append(parameters.get(key, None)) + self._call_batch.append(row) if self._flush_immediately: self._flush_calls() - def 
_get_matched_calls_for_filters( - self, project_id: str, call_ids: list[str] - ) -> list[tuple[ActionDispatchFilter, str, list[tsi.CallSchema]]]: - """Helper function to get calls that match action filters. - - Returns a list of tuples containing (filter, matched_calls) pairs. - """ - # Get all action filters for the project - filter_req = tsi.ObjQueryReq( - project_id=project_id, - filter=tsi.ObjectVersionFilter( - base_object_classes=[base_model_name(ActionDispatchFilter)], - is_op=False, - latest_only=True, # IMPORTANT: Always keep this. - ), - ) - filter_res = self.objs_query(filter_req) - filters: list[tuple[ActionDispatchFilter, str]] = [ - ( - ActionDispatchFilter.model_validate(obj.val), - ri.InternalObjectRef( - project_id=project_id, - name=obj.object_id, - version=obj.digest, - ).uri(), - ) - for obj in filter_res.objs - ] - - if not filters: - return [] - - # Get all finished calls from the batch - calls_query_filter = tsi.Query.model_validate( - { - "$expr": { - "$and": [ - { - "$in": [ - {"$getField": "id"}, - [{"$literal": id} for id in call_ids], - ] - }, - { - "$not": ( - { - "$eq": [ - {"$getField": "started_at"}, - {"$literal": None}, - ] - }, - ) - }, - { - "$not": ( - { - "$eq": [ - {"$getField": "ended_at"}, - {"$literal": None}, - ] - }, - ) - }, - ] - } - } - ) - calls_query_req = tsi.CallsQueryReq( - project_id=project_id, - query=calls_query_filter, - ) - calls_res = self.calls_query(calls_query_req) - - # TODO: this can be lifted to the top once core's deps are updated - import xxhash - - # Match calls to filters - matched_filters_and_calls = [] - for filter, filter_ref in filters: - if filter.disabled: - continue - calls_with_refs = [ - (call, ri.parse_internal_uri(call.op_name)) for call in calls_res.calls - ] - matched_calls = [ - c[0] - for c in calls_with_refs - if isinstance(c[1], ri.InternalObjectRef) - and c[1].name == filter.op_name - ] - - if filter.sample_rate < 1.0: - matched_calls = [ - call - for call in matched_calls - if 
abs(xxhash.xxh32(call.id).intdigest()) % 10000 - < filter.sample_rate * 10000 - ] - - if matched_calls: - matched_filters_and_calls.append((filter, filter_ref, matched_calls)) - - return matched_filters_and_calls - def _flush_calls(self) -> None: - if not self._call_batch: - return - calls = self._call_batch.calls - if not calls: - return - project_id = self._call_batch.project_id - - assert project_id is not None - try: - self._insert_call_batch(calls) + self._insert_call_batch(self._call_batch) except InsertTooLarge: logger.info("Retrying with large objects stripped.") - batch = self._strip_large_values(calls) + batch = self._strip_large_values(self._call_batch) self._insert_call_batch(batch) - # Find and process matched calls for each filter - matched_filters_and_calls = self._get_matched_calls_for_filters( - project_id, self._call_batch.call_ids - ) - - for filter, filter_ref, matched_calls in matched_filters_and_calls: - self.actions_execute_batch( - tsi.ActionsExecuteBatchReq( - project_id=project_id, - call_ids=[call.id for call in matched_calls], - configured_action_ref=filter.configured_action_ref, - trigger_ref=filter_ref, - ) - ) - - self._call_batch.clear() + self._call_batch = [] def _strip_large_values(self, batch: list[list[Any]]) -> list[list[Any]]: """ @@ -2305,6 +2021,21 @@ def _process_parameters( return parameters +# def _partial_obj_schema_to_ch_obj( +# partial_obj: tsi.ObjSchemaForInsert, +# ) -> ObjCHInsertable: +# version_hash = version_hash_for_object(partial_obj) + +# return ObjCHInsertable( +# id=uuid.uuid4(), +# project_id=partial_obj.project_id, +# name=partial_obj.name, +# type="unknown", +# refs=[], +# val=json.dumps(partial_obj.val), +# ) + + def get_type(val: Any) -> str: if val == None: return "none" From acea359831fdae80679db743ca350a957fcdc7e8 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 14:52:26 -0800 Subject: [PATCH 020/120] A bunch of name changes and adjustments --- tests/trace/test_actions_e2e.py | 14 
++++----- weave/actions_worker/tasks.py | 12 ++++---- .../clickhouse_trace_server_batched.py | 4 +++ .../base_models/action_base_models.py | 29 +++++++++---------- weave/trace_server/trace_server_interface.py | 3 +- 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/tests/trace/test_actions_e2e.py b/tests/trace/test_actions_e2e.py index 548ec22ea94..49bbad55432 100644 --- a/tests/trace/test_actions_e2e.py +++ b/tests/trace/test_actions_e2e.py @@ -4,8 +4,8 @@ from weave.trace.refs import ObjectRef from weave.trace.weave_client import WeaveClient from weave.trace_server.interface.base_models.action_base_models import ( - ConfiguredAction, - ConfiguredContainsWordsAction, + Action, + ContainsWordsActionSpec, ) from weave.trace_server.sqlite_trace_server import SqliteTraceServer from weave.trace_server.trace_server_interface import ( @@ -30,10 +30,10 @@ def test_action_execute_workflow(client: WeaveClient): "obj": { "project_id": client._project_id(), "object_id": action_name, - "base_object_class": "ConfiguredAction", - "val": ConfiguredAction( + "base_object_class": "Action", + "val": Action( name="test_action", - config=ConfiguredContainsWordsAction( + config=ContainsWordsActionSpec( target_words=["mindful", "demure"] ), ).model_dump(), @@ -46,7 +46,7 @@ def test_action_execute_workflow(client: WeaveClient): ObjQueryReq.model_validate( { "project_id": client._project_id(), - "filter": {"base_object_classes": ["ConfiguredAction"]}, + "filter": {"base_object_classes": ["Action"]}, } ) ) @@ -110,7 +110,7 @@ def example_op(input: str) -> str: { "project_id": client._project_id(), "call_ids": [call2.id], - "configured_action_ref": action_ref_uri, + "action_ref": action_ref_uri, } ) ) diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index 9b70171c73b..1fb1122b4d2 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -6,8 +6,8 @@ from weave.trace_server.feedback import RunnablePayloadSchema from 
weave.trace_server.interface.base_models.action_base_models import ( - ActionConfigType, - ConfiguredLlmJudgeAction, + Action, + LlmJudgeActionSpec, ) from weave.trace_server.refs_internal import ( InternalCallRef, @@ -43,7 +43,7 @@ def publish_results_as_feedback( ) -def do_llm_judge_action(config: ConfiguredLlmJudgeAction, call: CallSchema) -> Any: +def do_llm_judge_action(config: LlmJudgeActionSpec, call: CallSchema) -> Any: model = config.model system_prompt = config.prompt if config.response_format is None: @@ -90,12 +90,12 @@ def do_llm_judge_action(config: ConfiguredLlmJudgeAction, call: CallSchema) -> A def do_action( - configured_action_ref: str, action_config: ActionConfigType, call: CallSchema + action_ref: str, action_config: Action, call: CallSchema ) -> Tuple[Any, FeedbackCreateReq]: runnable_ref = None - if isinstance(action_config, ConfiguredLlmJudgeAction): + if isinstance(action_config, LlmJudgeActionSpec): result = do_llm_judge_action(action_config, call) - runnable_ref = configured_action_ref + runnable_ref = action_ref else: raise ValueError(f"Unsupported action config: {action_config}") req = publish_results_as_feedback(call, runnable_ref, result) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 0ae6eedd9a3..1d97e6f5ef9 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1365,6 +1365,10 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: "feedback_type": req.feedback_type, "payload": req.payload, "created_at": created_at, + "annotation_ref": req.annotation_ref, + "runnable_ref": req.runnable_ref, + "call_ref": req.call_ref, + "trigger_ref": req.trigger_ref, } prepared = TABLE_FEEDBACK.insert(row).prepare(database_type="clickhouse") self._insert(TABLE_FEEDBACK.name, prepared.data, prepared.column_names) diff --git 
a/weave/trace_server/interface/base_models/action_base_models.py b/weave/trace_server/interface/base_models/action_base_models.py index 0b9dbb5c475..c80430a217e 100644 --- a/weave/trace_server/interface/base_models/action_base_models.py +++ b/weave/trace_server/interface/base_models/action_base_models.py @@ -2,46 +2,45 @@ from pydantic import BaseModel -# class BaseActionModel(BaseModel): -# action_type: str - -class ConfiguredLlmJudgeAction(BaseModel): +class LlmJudgeActionSpec(BaseModel): action_type: Literal["llm_judge"] = "llm_judge" model: Literal["gpt-4o", "gpt-4o-mini"] prompt: str response_format: Optional[dict[str, Any]] -class ConfiguredContainsWordsAction(BaseModel): +class ContainsWordsActionSpec(BaseModel): action_type: Literal["contains_words"] = "contains_words" target_words: list[str] -class ConfiguredWordCountAction(BaseModel): +class WordCountActionSpec(BaseModel): action_type: Literal["wordcount"] = "wordcount" -class ConfiguredNoopAction(BaseModel): +class NoopActionSpec(BaseModel): action_type: Literal["noop"] = "noop" -ActionConfigType = Union[ - ConfiguredLlmJudgeAction, - ConfiguredContainsWordsAction, - ConfiguredWordCountAction, - ConfiguredNoopAction, +ActionSpecType = Union[ + LlmJudgeActionSpec, + ContainsWordsActionSpec, + WordCountActionSpec, + NoopActionSpec, ] -class ConfiguredAction(BaseModel): +# TODO: Make this a baseObject +class Action(BaseModel): name: str - config: ActionConfigType + spec: ActionSpecType # CRITICAL! 
When saving this object, always set the object_id == "op_id-action_id" +# TODO: Make this a baseObject class ActionDispatchFilter(BaseModel): op_name: str sample_rate: float - configured_action_ref: str + action_ref: str disabled: Optional[bool] = False diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index ea9027abbda..ab88f352db6 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -698,6 +698,7 @@ class FeedbackCreateReq(BaseModel): trigger_ref: Optional[str] = Field( default=None, examples=["weave:///entity/project/object/name:digest"] ) + # wb_user_id is automatically populated by the server wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) @@ -850,8 +851,8 @@ class CostPurgeRes(BaseModel): class ActionsExecuteBatchReq(BaseModel): project_id: str + action_ref: str call_ids: list[str] - configured_action_ref: str class ActionsExecuteBatchRes(BaseModel): From f21c6b4ab1792ba90a7f48bdd5e0faca001da817 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 14:55:11 -0800 Subject: [PATCH 021/120] Move actions --- weave/actions_worker/tasks.py | 2 +- .../actions.py} | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) rename weave/trace_server/interface/{base_models/action_base_models.py => base_object_classes/actions.py} (80%) diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index 1fb1122b4d2..cbd478ec76a 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -5,7 +5,7 @@ from openai import OpenAI from weave.trace_server.feedback import RunnablePayloadSchema -from weave.trace_server.interface.base_models.action_base_models import ( +from weave.trace_server.interface.base_object_classes.actions import ( Action, LlmJudgeActionSpec, ) diff --git a/weave/trace_server/interface/base_models/action_base_models.py 
b/weave/trace_server/interface/base_object_classes/actions.py similarity index 80% rename from weave/trace_server/interface/base_models/action_base_models.py rename to weave/trace_server/interface/base_object_classes/actions.py index c80430a217e..4e2bc190e6c 100644 --- a/weave/trace_server/interface/base_models/action_base_models.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -1,5 +1,5 @@ from typing import Any, Literal, Optional, Union - +from weave.trace_server.interface.base_object_classes import base_object_def from pydantic import BaseModel @@ -31,16 +31,14 @@ class NoopActionSpec(BaseModel): ] -# TODO: Make this a baseObject -class Action(BaseModel): +class Action(base_object_def.BaseObject): name: str spec: ActionSpecType # CRITICAL! When saving this object, always set the object_id == "op_id-action_id" -# TODO: Make this a baseObject -class ActionDispatchFilter(BaseModel): +class ActionDispatchFilter(base_object_def.BaseObject): op_name: str sample_rate: float - action_ref: str + action_ref: base_object_def.RefStr disabled: Optional[bool] = False From 05c525d268f52e7b52b49a4015e5797d71603b42 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 14:59:11 -0800 Subject: [PATCH 022/120] Move actions again --- weave/actions_worker/tasks.py | 4 ++-- .../interface/base_object_classes/actions.py | 15 +++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index cbd478ec76a..a2a00408bdb 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -6,7 +6,7 @@ from weave.trace_server.feedback import RunnablePayloadSchema from weave.trace_server.interface.base_object_classes.actions import ( - Action, + ActionDefinition, LlmJudgeActionSpec, ) from weave.trace_server.refs_internal import ( @@ -90,7 +90,7 @@ def do_llm_judge_action(config: LlmJudgeActionSpec, call: CallSchema) -> Any: def do_action( - action_ref: str, 
action_config: Action, call: CallSchema + action_ref: str, action_config: ActionDefinition, call: CallSchema ) -> Tuple[Any, FeedbackCreateReq]: runnable_ref = None if isinstance(action_config, LlmJudgeActionSpec): diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index 4e2bc190e6c..a67f4676069 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -1,7 +1,9 @@ from typing import Any, Literal, Optional, Union -from weave.trace_server.interface.base_object_classes import base_object_def + from pydantic import BaseModel +from weave.trace_server.interface.base_object_classes import base_object_def + class LlmJudgeActionSpec(BaseModel): action_type: Literal["llm_judge"] = "llm_judge" @@ -31,14 +33,7 @@ class NoopActionSpec(BaseModel): ] -class Action(base_object_def.BaseObject): +# TODO: Make sure we really like this name - it is permanent +class ActionDefinition(base_object_def.BaseObject): name: str spec: ActionSpecType - - -# CRITICAL! 
When saving this object, always set the object_id == "op_id-action_id" -class ActionDispatchFilter(base_object_def.BaseObject): - op_name: str - sample_rate: float - action_ref: base_object_def.RefStr - disabled: Optional[bool] = False From 94ee62ead9f6fb2ed19c86150f4878867fc54a18 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 15:11:23 -0800 Subject: [PATCH 023/120] beginning to refactor tests --- .../{test_actions_e2e.py => test_actions_lifecycle.py} | 10 +++++----- .../base_object_classes/base_object_registry.py | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) rename tests/trace/{test_actions_e2e.py => test_actions_lifecycle.py} (93%) diff --git a/tests/trace/test_actions_e2e.py b/tests/trace/test_actions_lifecycle.py similarity index 93% rename from tests/trace/test_actions_e2e.py rename to tests/trace/test_actions_lifecycle.py index 49bbad55432..8ee9b455516 100644 --- a/tests/trace/test_actions_e2e.py +++ b/tests/trace/test_actions_lifecycle.py @@ -3,8 +3,8 @@ import weave from weave.trace.refs import ObjectRef from weave.trace.weave_client import WeaveClient -from weave.trace_server.interface.base_models.action_base_models import ( - Action, +from weave.trace_server.interface.base_object_classes.actions import ( + ActionDefinition, ContainsWordsActionSpec, ) from weave.trace_server.sqlite_trace_server import SqliteTraceServer @@ -30,10 +30,10 @@ def test_action_execute_workflow(client: WeaveClient): "obj": { "project_id": client._project_id(), "object_id": action_name, - "base_object_class": "Action", - "val": Action( + "base_object_class": "ActionDefinition", + "val": ActionDefinition( name="test_action", - config=ContainsWordsActionSpec( + spec=ContainsWordsActionSpec( target_words=["mindful", "demure"] ), ).model_dump(), diff --git a/weave/trace_server/interface/base_object_classes/base_object_registry.py b/weave/trace_server/interface/base_object_classes/base_object_registry.py index 843598d5979..acfec80293a 100644 --- 
a/weave/trace_server/interface/base_object_classes/base_object_registry.py +++ b/weave/trace_server/interface/base_object_classes/base_object_registry.py @@ -1,5 +1,6 @@ from typing import Dict, Type +from weave.trace_server.interface.base_object_classes.actions import ActionDefinition from weave.trace_server.interface.base_object_classes.base_object_def import BaseObject from weave.trace_server.interface.base_object_classes.leaderboard import Leaderboard from weave.trace_server.interface.base_object_classes.test_only_example import ( @@ -23,3 +24,4 @@ def register_base_object(cls: Type[BaseObject]) -> None: register_base_object(TestOnlyExample) register_base_object(TestOnlyNestedBaseObject) register_base_object(Leaderboard) +register_base_object(ActionDefinition) From 1a58795329d338430c61e049b9e919df61658f9f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 15:42:10 -0800 Subject: [PATCH 024/120] Beginning to fix the tests themselves --- tests/trace/test_actions_lifecycle.py | 100 ++++++++------------------ 1 file changed, 28 insertions(+), 72 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 8ee9b455516..33b1a26274e 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -1,92 +1,52 @@ import pytest import weave -from weave.trace.refs import ObjectRef +from tests.trace.util import client_is_sqlite from weave.trace.weave_client import WeaveClient from weave.trace_server.interface.base_object_classes.actions import ( ActionDefinition, - ContainsWordsActionSpec, ) -from weave.trace_server.sqlite_trace_server import SqliteTraceServer from weave.trace_server.trace_server_interface import ( ActionsExecuteBatchReq, FeedbackCreateReq, - ObjCreateReq, - ObjQueryReq, ) -def test_action_execute_workflow(client: WeaveClient): - is_sqlite = isinstance(client.server._internal_trace_server, SqliteTraceServer) # type: ignore - if is_sqlite: - # dont run this test 
for sqlite - return +def test_action_lifecycle_simple(client: WeaveClient): + if client_is_sqlite(client): + return pytest.skip("skipping for sqlite") - action_name = "test_action" - # part 1: create the action - digest = client.server.obj_create( - ObjCreateReq.model_validate( - { - "obj": { - "project_id": client._project_id(), - "object_id": action_name, - "base_object_class": "ActionDefinition", - "val": ActionDefinition( - name="test_action", - spec=ContainsWordsActionSpec( - target_words=["mindful", "demure"] - ), - ).model_dump(), - } - } - ) - ).digest + action_name = "my_contains_words_action" - configured_actions = client.server.objs_query( - ObjQueryReq.model_validate( - { - "project_id": client._project_id(), - "filter": {"base_object_classes": ["Action"]}, - } + published_ref = weave.publish( + ActionDefinition( + name=action_name, + spec={ + "action_type": "contains_words", + "target_words": ["mindful", "demure"], + }, ) ) - assert len(configured_actions.objs) == 1 - assert configured_actions.objs[0].digest == digest - action_ref_uri = ObjectRef( - entity=client.entity, - project=client.project, - name=action_name, - _digest=digest, - ).uri() + # Construct the URI + action_ref_uri = published_ref.uri() - # part 2: manually create feedback + # Part 2: Demonstrate manual feedback (this is not user-facing) @weave.op def example_op(input: str) -> str: return input[::-1] - _, call1 = example_op.call("i've been very mindful today") - with pytest.raises(Exception): - client.server.feedback_create( - FeedbackCreateReq.model_validate( - { - "project_id": client._project_id(), - "weave_ref": call1.ref.uri(), - "feedback_type": "MachineScore", - "payload": True, - } - ) - ) + _, call1 = example_op.call("i've been very distracted today") res = client.server.feedback_create( FeedbackCreateReq.model_validate( { "project_id": client._project_id(), "weave_ref": call1.ref.uri(), - "feedback_type": "MachineScore", + "feedback_type": "wandb.runnable." 
+ action_name, + "runnable_ref": action_ref_uri, "payload": { - "runnable_ref": action_ref_uri, - "value": {action_name: {digest: True}}, + "output": False, }, } ) @@ -94,12 +54,10 @@ def example_op(input: str) -> str: feedbacks = list(call1.feedback) assert len(feedbacks) == 1 - assert feedbacks[0].payload == { - "runnable_ref": action_ref_uri, - "value": {action_name: {digest: True}}, - "call_ref": None, - "trigger_ref": None, - } + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." + action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == {"output": False} # Step 3: test that we can in-place execute one action at a time. @@ -109,17 +67,15 @@ def example_op(input: str) -> str: ActionsExecuteBatchReq.model_validate( { "project_id": client._project_id(), - "call_ids": [call2.id], "action_ref": action_ref_uri, + "call_ids": [call2.id], } ) ) feedbacks = list(call2.feedback) assert len(feedbacks) == 1 - assert feedbacks[0].payload == { - "runnable_ref": action_ref_uri, - "value": {action_name: {digest: False}}, - "call_ref": None, - "trigger_ref": None, - } + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." + action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == {"output": True} From 554f82fafb1dda375f3d1b7f923862b643c4f774 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 15:54:05 -0800 Subject: [PATCH 025/120] Added remaining tests --- tests/trace/test_actions_lifecycle.py | 46 +++++++++++++++++++ .../interface/base_object_classes/actions.py | 1 + 2 files changed, 47 insertions(+) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 33b1a26274e..39b344d4054 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -79,3 +79,49 @@ def example_op(input: str) -> str: assert feedback.feedback_type == "wandb.runnable." 
+ action_name assert feedback.runnable_ref == action_ref_uri assert feedback.payload == {"output": True} + + +def test_action_lifecycle_llm_judge(client: WeaveClient): + if client_is_sqlite(client): + return pytest.skip("skipping for sqlite") + + action_name = "response_is_mindful" + + published_ref = weave.publish( + ActionDefinition( + name=action_name, + spec={ + "action_type": "llm_judge", + "model": "gpt-4o-mini", + "prompt": "Is the response mindful?", + "response_format": {"type": "boolean"}, + }, + ) + ) + + # Construct the URI + action_ref_uri = published_ref.uri() + + @weave.op + def example_op(input: str) -> str: + return input[::-1] + + # Step 2: test that we can in-place execute one action at a time. + _, call = example_op.call("i've been very meditative today") + + res = client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) + ) + + feedbacks = list(call.feedback) + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." 
+ action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == {"output": True} diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index a67f4676069..cebf4db712d 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -7,6 +7,7 @@ class LlmJudgeActionSpec(BaseModel): action_type: Literal["llm_judge"] = "llm_judge" + # TODO: Remove this restriction model: Literal["gpt-4o", "gpt-4o-mini"] prompt: str response_format: Optional[dict[str, Any]] From 2b0cf0c58f4f1a8c9f5d8123913449755332f04e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 16:15:40 -0800 Subject: [PATCH 026/120] Lint merge --- weave/trace/weave_client.py | 2 +- weave/trace_server/calls_query_builder.py | 2 +- weave/trace_server/feedback.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 404c7b2af08..797a2dd76d6 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -327,7 +327,7 @@ def _apply_scorer(self, scorer_op: Op) -> None: score_call_ref = get_ref(score_call) if score_call_ref is None: raise ValueError("Score call has no ref") - return client._add_score( + client._add_score( call_ref_uri=self_ref.uri(), score_name=score_name, score_results=score_results, diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 8e13373472d..20fe231a6e7 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -143,7 +143,7 @@ def as_sql( self, pb: ParamBuilder, table_alias: str, - cast: typing.Optional[tsi_query.CastTo] = None, + cast: Optional[tsi_query.CastTo] = None, ) -> str: inner = super().as_sql(pb, "feedback") param_name = pb.add_param(self.feedback_type) diff --git 
a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 72303c42183..05545135195 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Any, Optional, TypeVar, Union, overload from pydantic import BaseModel, ValidationError @@ -65,12 +65,12 @@ def _ensure_ref_is_valid( @overload def _ensure_ref_is_valid( ref: str, - expected_type: Tuple[Type[T], ...], + expected_type: tuple[type[T], ...], ) -> T: ... def _ensure_ref_is_valid( - ref: str, expected_type: Optional[Tuple[Type, ...]] = None + ref: str, expected_type: Optional[tuple[type, ...]] = None ) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: """Validates and parses an internal URI reference. From 95d463bbe91a06b113940a20808ef5f1e334a704 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 16:36:19 -0800 Subject: [PATCH 027/120] Initial Comments --- weave/trace_server/calls_query_builder.py | 1 - weave/trace_server/feedback.py | 65 ++++--------------- weave/trace_server/orm.py | 15 ++--- .../trace_server/refs_internal_server_util.py | 47 ++++++++++++++ weave/trace_server/trace_server_interface.py | 2 + 5 files changed, 70 insertions(+), 60 deletions(-) create mode 100644 weave/trace_server/refs_internal_server_util.py diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 20fe231a6e7..abfe82b0887 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -258,7 +258,6 @@ def as_sql(self, pb: ParamBuilder, table_alias: str) -> str: if self._consumed_fields is None: self._consumed_fields = [] for field in conditions.fields_used: - # TODO: Verify that this is "ok" since before we were just looking at field name self._consumed_fields.append(field) return combine_conditions(conditions.conditions, "AND") diff --git 
a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 05545135195..82c17202c98 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypeVar, Union, overload +from typing import Any from pydantic import BaseModel, ValidationError @@ -6,6 +6,7 @@ from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest from weave.trace_server.orm import Column, Table +from weave.trace_server.refs_internal_server_util import ensure_ref_is_valid from weave.trace_server.validation import ( validate_purge_req_multiple, validate_purge_req_one, @@ -51,48 +52,12 @@ class RunnablePayloadSchema(BaseModel): output: Any -T = TypeVar( - "T", ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef, ri.InternalOpRef -) - - -@overload -def _ensure_ref_is_valid( - ref: str, expected_type: None = None -) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: ... - - -@overload -def _ensure_ref_is_valid( - ref: str, - expected_type: tuple[type[T], ...], -) -> T: ... +def feedback_type_is_annotation(feedback_type: str) -> bool: + return feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX) -def _ensure_ref_is_valid( - ref: str, expected_type: Optional[tuple[type, ...]] = None -) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: - """Validates and parses an internal URI reference. 
- - Args: - ref: The reference string to validate - expected_type: Optional tuple of expected reference types - - Returns: - The parsed internal reference object - - Raises: - InvalidRequest: If the reference is invalid or doesn't match expected_type - """ - try: - parsed_ref = ri.parse_internal_uri(ref) - except ValueError as e: - raise InvalidRequest(f"Invalid ref: {ref}, {e}") - if expected_type and not isinstance(parsed_ref, expected_type): - raise InvalidRequest( - f"Invalid ref: {ref}, expected {(t.__name__ for t in expected_type)}" - ) - return parsed_ref +def feedback_type_is_runnable(feedback_type: str) -> bool: + return feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX) def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: @@ -106,9 +71,7 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: ) # Validate the required fields for the feedback type. - is_annotation = req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX) - is_runnable = req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX) - if is_annotation: + if feedback_type_is_annotation(req.feedback_type): if not req.feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX + "."): raise InvalidRequest( f"Invalid annotation feedback type: {req.feedback_type}" @@ -116,7 +79,7 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: type_subname = req.feedback_type[len(ANNOTATION_FEEDBACK_TYPE_PREFIX) + 1 :] if not req.annotation_ref: raise InvalidRequest("annotation_ref is required for annotation feedback") - annotation_ref = _ensure_ref_is_valid( + annotation_ref = ensure_ref_is_valid( req.annotation_ref, (ri.InternalObjectRef,) ) if annotation_ref.name != type_subname: @@ -133,13 +96,13 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: raise InvalidRequest( "annotation_ref is not allowed for non-annotation feedback" ) - elif is_runnable: + elif feedback_type_is_runnable(req.feedback_type): if not 
req.feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX + "."): raise InvalidRequest(f"Invalid runnable feedback type: {req.feedback_type}") type_subname = req.feedback_type[len(RUNNABLE_FEEDBACK_TYPE_PREFIX) + 1 :] if not req.runnable_ref: raise InvalidRequest("runnable_ref is required for runnable feedback") - runnable_ref = _ensure_ref_is_valid( + runnable_ref = ensure_ref_is_valid( req.runnable_ref, (ri.InternalOpRef, ri.InternalObjectRef) ) if runnable_ref.name != type_subname: @@ -163,13 +126,13 @@ def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: # Validate the ref formats (we could even query the DB to ensure they exist and are valid) if req.annotation_ref: - _ensure_ref_is_valid(req.annotation_ref, (ri.InternalObjectRef,)) + ensure_ref_is_valid(req.annotation_ref, (ri.InternalObjectRef,)) if req.runnable_ref: - _ensure_ref_is_valid(req.runnable_ref, (ri.InternalOpRef, ri.InternalObjectRef)) + ensure_ref_is_valid(req.runnable_ref, (ri.InternalOpRef, ri.InternalObjectRef)) if req.call_ref: - _ensure_ref_is_valid(req.call_ref, (ri.InternalCallRef,)) + ensure_ref_is_valid(req.call_ref, (ri.InternalCallRef,)) if req.trigger_ref: - _ensure_ref_is_valid(req.trigger_ref, (ri.InternalObjectRef,)) + ensure_ref_is_valid(req.trigger_ref, (ri.InternalObjectRef,)) MESSAGE_INVALID_FEEDBACK_PURGE = ( diff --git a/weave/trace_server/orm.py b/weave/trace_server/orm.py index dd3c931d7cf..4aac2b1e2e6 100644 --- a/weave/trace_server/orm.py +++ b/weave/trace_server/orm.py @@ -51,17 +51,16 @@ def __init__( self._param_to_name: dict[typing.Any, str] = {} def add_param(self, param_value: typing.Any) -> str: - try: + param_name = self._prefix + str(len(self._params)) + + # Only attempt caching for hashable values + if isinstance(param_value, typing.Hashable): if param_value in self._param_to_name: return self._param_to_name[param_value] - except TypeError: - pass - param_name = self._prefix + str(len(self._params)) - self._params[param_name] = param_value - 
try: self._param_to_name[param_value] = param_name - except TypeError: - pass + + # For non-hashable values, just generate a new param without caching + self._params[param_name] = param_value return param_name def add( diff --git a/weave/trace_server/refs_internal_server_util.py b/weave/trace_server/refs_internal_server_util.py new file mode 100644 index 00000000000..447f43f2bdf --- /dev/null +++ b/weave/trace_server/refs_internal_server_util.py @@ -0,0 +1,47 @@ +from typing import Optional, TypeVar, Union, overload + +from weave.trace_server import refs_internal as ri +from weave.trace_server.errors import InvalidRequest + +T = TypeVar( + "T", ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef, ri.InternalOpRef +) + + +@overload +def ensure_ref_is_valid( + ref: str, expected_type: None = None +) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: ... + + +@overload +def ensure_ref_is_valid( + ref: str, + expected_type: tuple[type[T], ...], +) -> T: ... + + +def ensure_ref_is_valid( + ref: str, expected_type: Optional[tuple[type, ...]] = None +) -> Union[ri.InternalObjectRef, ri.InternalTableRef, ri.InternalCallRef]: + """Validates and parses an internal URI reference. 
+ + Args: + ref: The reference string to validate + expected_type: Optional tuple of expected reference types + + Returns: + The parsed internal reference object + + Raises: + InvalidRequest: If the reference is invalid or doesn't match expected_type + """ + try: + parsed_ref = ri.parse_internal_uri(ref) + except ValueError as e: + raise InvalidRequest(f"Invalid ref: {ref}, {e}") + if expected_type and not isinstance(parsed_ref, expected_type): + raise InvalidRequest( + f"Invalid ref: {ref}, expected {(t.__name__ for t in expected_type)}" + ) + return parsed_ref diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index c4a4f0ea25b..37c49c2cd07 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -690,6 +690,8 @@ class FeedbackCreateReq(BaseModel): } ] ) + # TODO: From Griffin: `it would be nice if we could type this to a kind of ref, + # like objectRef, with a pydantic validator and then check its construction in the client.` annotation_ref: Optional[str] = Field( default=None, examples=["weave:///entity/project/object/name:digest"] ) From a6dc197bb01b010505ace78033ac8fa3861796ef Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 16:45:50 -0800 Subject: [PATCH 028/120] Moved Feedback Symbols --- weave/trace_server/feedback.py | 42 +++++-------------- .../trace_server/interface/feedback_types.py | 40 ++++++++++++++++++ weave/trace_server/trace_server_interface.py | 8 ---- 3 files changed, 50 insertions(+), 40 deletions(-) create mode 100644 weave/trace_server/interface/feedback_types.py diff --git a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 82c17202c98..7efedcd8b8c 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -1,10 +1,17 @@ -from typing import Any - -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from weave.trace_server import 
refs_internal as ri from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest +from weave.trace_server.interface.feedback_types import ( + ANNOTATION_FEEDBACK_TYPE_PREFIX, + FEEDBACK_PAYLOAD_SCHEMAS, + RUNNABLE_FEEDBACK_TYPE_PREFIX, + AnnotationPayloadSchema, + RunnablePayloadSchema, + feedback_type_is_annotation, + feedback_type_is_runnable, +) from weave.trace_server.orm import Column, Table from weave.trace_server.refs_internal_server_util import ensure_ref_is_valid from weave.trace_server.validation import ( @@ -31,35 +38,6 @@ ) -FEEDBACK_PAYLOAD_SCHEMAS: dict[str, type[BaseModel]] = { - "wandb.reaction.1": tsi.FeedbackPayloadReactionReq, - "wandb.note.1": tsi.FeedbackPayloadNoteReq, -} - -ANNOTATION_FEEDBACK_TYPE_PREFIX = "wandb.annotation" -RUNNABLE_FEEDBACK_TYPE_PREFIX = "wandb.runnable" - - -# Making the decision to use `value` & `payload` as nested keys so that -# we can: -# 1. Add more fields in the future without breaking changes -# 2. 
Support primitive values for annotation feedback that still schema -class AnnotationPayloadSchema(BaseModel): - value: Any - - -class RunnablePayloadSchema(BaseModel): - output: Any - - -def feedback_type_is_annotation(feedback_type: str) -> bool: - return feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX) - - -def feedback_type_is_runnable(feedback_type: str) -> bool: - return feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX) - - def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: payload_schema = FEEDBACK_PAYLOAD_SCHEMAS.get(req.feedback_type) if payload_schema: diff --git a/weave/trace_server/interface/feedback_types.py b/weave/trace_server/interface/feedback_types.py new file mode 100644 index 00000000000..673174d4693 --- /dev/null +++ b/weave/trace_server/interface/feedback_types.py @@ -0,0 +1,40 @@ +from typing import Any + +from pydantic import BaseModel, Field + + +class FeedbackPayloadReactionReq(BaseModel): + emoji: str + + +class FeedbackPayloadNoteReq(BaseModel): + note: str = Field(min_length=1, max_length=1024) + + +FEEDBACK_PAYLOAD_SCHEMAS: dict[str, type[BaseModel]] = { + "wandb.reaction.1": FeedbackPayloadReactionReq, + "wandb.note.1": FeedbackPayloadNoteReq, +} + +ANNOTATION_FEEDBACK_TYPE_PREFIX = "wandb.annotation" +RUNNABLE_FEEDBACK_TYPE_PREFIX = "wandb.runnable" + + +# Making the decision to use `value` & `payload` as nested keys so that +# we can: +# 1. Add more fields in the future without breaking changes +# 2. 
Support primitive values for annotation feedback that still schema +class AnnotationPayloadSchema(BaseModel): + value: Any + + +class RunnablePayloadSchema(BaseModel): + output: Any + + +def feedback_type_is_annotation(feedback_type: str) -> bool: + return feedback_type.startswith(ANNOTATION_FEEDBACK_TYPE_PREFIX) + + +def feedback_type_is_runnable(feedback_type: str) -> bool: + return feedback_type.startswith(RUNNABLE_FEEDBACK_TYPE_PREFIX) diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 37c49c2cd07..a437ac1639e 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -670,14 +670,6 @@ class RefsReadBatchRes(BaseModel): vals: list[Any] -class FeedbackPayloadReactionReq(BaseModel): - emoji: str - - -class FeedbackPayloadNoteReq(BaseModel): - note: str = Field(min_length=1, max_length=1024) - - class FeedbackCreateReq(BaseModel): project_id: str = Field(examples=["entity/project"]) weave_ref: str = Field(examples=["weave:///entity/project/object/name:digest"]) From 33ad7d12c8ef5130a2c25b70afc646911d85373f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 16:50:39 -0800 Subject: [PATCH 029/120] Merged and linted --- weave/actions_worker/tasks.py | 6 +++--- weave/trace_server/interface/base_object_classes/actions.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index a2a00408bdb..c165a2aec49 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -1,14 +1,14 @@ import json from functools import partial -from typing import Any, Tuple +from typing import Any from openai import OpenAI -from weave.trace_server.feedback import RunnablePayloadSchema from weave.trace_server.interface.base_object_classes.actions import ( ActionDefinition, LlmJudgeActionSpec, ) +from weave.trace_server.interface.feedback_types import RunnablePayloadSchema 
from weave.trace_server.refs_internal import ( InternalCallRef, InternalObjectRef, @@ -91,7 +91,7 @@ def do_llm_judge_action(config: LlmJudgeActionSpec, call: CallSchema) -> Any: def do_action( action_ref: str, action_config: ActionDefinition, call: CallSchema -) -> Tuple[Any, FeedbackCreateReq]: +) -> tuple[Any, FeedbackCreateReq]: runnable_ref = None if isinstance(action_config, LlmJudgeActionSpec): result = do_llm_judge_action(action_config, call) diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index cebf4db712d..1bc465427d0 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -36,5 +36,6 @@ class NoopActionSpec(BaseModel): # TODO: Make sure we really like this name - it is permanent class ActionDefinition(base_object_def.BaseObject): - name: str + # Pyright doesn't like this override + # name: str spec: ActionSpecType From 0c4fc7a4cecb38affd5580980fc9864adfb1116d Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 17:29:22 -0800 Subject: [PATCH 030/120] Implemented initial action dispatching --- .../{tasks.py => actions/llm_judge.py} | 49 +------ weave/actions_worker/dispatcher.py | 122 ++++++++++++++++++ .../clickhouse_trace_server_batched.py | 13 +- .../interface/base_object_classes/actions.py | 10 -- .../base_object_classes/contains_words.py | 17 +++ 5 files changed, 155 insertions(+), 56 deletions(-) rename weave/actions_worker/{tasks.py => actions/llm_judge.py} (50%) create mode 100644 weave/actions_worker/dispatcher.py create mode 100644 weave/trace_server/interface/base_object_classes/contains_words.py diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/actions/llm_judge.py similarity index 50% rename from weave/actions_worker/tasks.py rename to weave/actions_worker/actions/llm_judge.py index c165a2aec49..34f69fe7415 100644 --- 
a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/actions/llm_judge.py @@ -5,45 +5,17 @@ from openai import OpenAI from weave.trace_server.interface.base_object_classes.actions import ( - ActionDefinition, LlmJudgeActionSpec, ) -from weave.trace_server.interface.feedback_types import RunnablePayloadSchema -from weave.trace_server.refs_internal import ( - InternalCallRef, - InternalObjectRef, - InternalOpRef, - parse_internal_uri, -) from weave.trace_server.trace_server_interface import ( CallSchema, - FeedbackCreateReq, + TraceServerInterface, ) -def publish_results_as_feedback( - target_call: CallSchema, - runnable_ref: str, - result: Any, -) -> FeedbackCreateReq: - project_id = target_call.project_id - call_id = target_call.id - weave_ref = InternalCallRef(project_id, call_id).uri() - parsed_action_ref = parse_internal_uri(runnable_ref) - if not isinstance(parsed_action_ref, (InternalObjectRef, InternalOpRef)): - raise ValueError(f"Invalid action ref: {runnable_ref}") - action_name = parsed_action_ref.name - - return FeedbackCreateReq( - project_id=project_id, - weave_ref=weave_ref, - feedback_type="wandb.runnable." 
+ action_name, - runnable_ref=runnable_ref, - payload=RunnablePayloadSchema(output=result).model_dump(), - ) - - -def do_llm_judge_action(config: LlmJudgeActionSpec, call: CallSchema) -> Any: +def do_llm_judge_action( + config: LlmJudgeActionSpec, call: CallSchema, trace_server: TraceServerInterface +) -> Any: model = config.model system_prompt = config.prompt if config.response_format is None: @@ -87,16 +59,3 @@ def do_llm_judge_action(config: LlmJudgeActionSpec, call: CallSchema) -> Any: if response_is_not_object: res = res[dummy_key] return res - - -def do_action( - action_ref: str, action_config: ActionDefinition, call: CallSchema -) -> tuple[Any, FeedbackCreateReq]: - runnable_ref = None - if isinstance(action_config, LlmJudgeActionSpec): - result = do_llm_judge_action(action_config, call) - runnable_ref = action_ref - else: - raise ValueError(f"Unsupported action config: {action_config}") - req = publish_results_as_feedback(call, runnable_ref, result) - return result, req diff --git a/weave/actions_worker/dispatcher.py b/weave/actions_worker/dispatcher.py new file mode 100644 index 00000000000..9ca39402923 --- /dev/null +++ b/weave/actions_worker/dispatcher.py @@ -0,0 +1,122 @@ +from typing import Any, Callable + +from pydantic import BaseModel + +from weave.actions_worker.actions.llm_judge import do_llm_judge_action +from weave.trace_server.interface.base_object_classes.actions import ( + ActionDefinition, + ContainsWordsActionSpec, + LlmJudgeActionSpec, +) +from weave.trace_server.interface.base_object_classes.contains_words import ( + do_contains_words_action, +) +from weave.trace_server.interface.feedback_types import RunnablePayloadSchema +from weave.trace_server.refs_internal import ( + InternalCallRef, + InternalObjectRef, + InternalOpRef, + parse_internal_uri, +) +from weave.trace_server.trace_server_interface import ( + ActionsExecuteBatchReq, + CallSchema, + CallsFilter, + CallsQueryReq, + FeedbackCreateReq, + FeedbackCreateRes, + ObjReadReq, + 
TraceServerInterface, +) + +ActionFnType = Callable[[ActionDefinition, CallSchema, TraceServerInterface], Any] + +# TODO: Nail down this typing +dispatch_map: dict[str, ActionFnType] = { + LlmJudgeActionSpec.action_type: do_llm_judge_action, + ContainsWordsActionSpec.action_type: do_contains_words_action, +} + + +class ActionResult(BaseModel): + result: Any + feedback_res: FeedbackCreateRes + + +def execute_batch( + batch_req: ActionsExecuteBatchReq, + trace_server: TraceServerInterface, +) -> list[ActionResult]: + # 1. Lookup the action definition + parsed_ref = parse_internal_uri(batch_req.action_ref) + if parsed_ref.project_id != batch_req.project_id: + raise ValueError( + f"Action ref {batch_req.action_ref} does not match project_id {batch_req.project_id}" + ) + if not isinstance(parsed_ref, InternalObjectRef): + raise ValueError(f"Action ref {batch_req.action_ref} is not an object ref") + + action_def_read = trace_server.obj_read( + ObjReadReq( + project_id=batch_req.project_id, + object_id=parsed_ref.name, + digest=parsed_ref.version, + ) + ) + action_def = ActionDefinition.model_validate(action_def_read.obj.val) + + # Lookup the calls + calls_query = trace_server.calls_query( + CallsQueryReq( + project_id=batch_req.project_id, + filter=CallsFilter(call_ids=batch_req.call_ids), + ) + ) + calls = calls_query.calls + + # 2. 
Dispatch the action to each call + # TODO: Some actions may be able to be batched together + results = [] + for call in calls: + result = dispatch_action(batch_req.action_ref, action_def, call, trace_server) + results.append(result) + return results + + +def dispatch_action( + action_ref: str, + action_def: ActionDefinition, + target_call: CallSchema, + trace_server: TraceServerInterface, +) -> ActionResult: + action_type = action_def.spec.action_type + action_fn = dispatch_map[action_type] + result = action_fn(action_def.spec, target_call, trace_server) + feedback_res = publish_results_as_feedback( + target_call, action_ref, result, trace_server + ) + return ActionResult(result=result, feedback_res=feedback_res) + + +def publish_results_as_feedback( + target_call: CallSchema, + action_ref: str, + result: Any, + trace_server: TraceServerInterface, +) -> FeedbackCreateRes: + project_id = target_call.project_id + call_id = target_call.id + weave_ref = InternalCallRef(project_id, call_id).uri() + parsed_action_ref = parse_internal_uri(action_ref) + if not isinstance(parsed_action_ref, (InternalObjectRef, InternalOpRef)): + raise ValueError(f"Invalid action ref: {action_ref}") + action_name = parsed_action_ref.name + return trace_server.feedback_create( + FeedbackCreateReq( + project_id=project_id, + weave_ref=weave_ref, + feedback_type="wandb.runnable." 
+ action_name, + runnable_ref=action_ref, + payload=RunnablePayloadSchema(output=result).model_dump(), + ) + ) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 55fb474e29d..d39b6f5d575 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -46,6 +46,7 @@ from clickhouse_connect.driver.query import QueryResult from clickhouse_connect.driver.summary import QuerySummary +from weave.actions_worker.dispatcher import execute_batch from weave.trace_server import clickhouse_trace_server_migrator as wf_migrator from weave.trace_server import environment as wf_env from weave.trace_server import refs_internal as ri @@ -1410,7 +1411,17 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: - raise NotImplementedError() + if len(req.call_ids) == 0: + return tsi.ActionsExecuteBatchRes() + if len(req.call_ids) > 1: + raise NotImplementedError("Batching actions is not yet supported") + + execute_batch( + batch_req=req, + trace_server=self, + ) + + raise tsi.ActionsExecuteBatchRes() def completions_create( self, req: tsi.CompletionsCreateReq diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index 1bc465427d0..ac0a19ba9d4 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -18,19 +18,9 @@ class ContainsWordsActionSpec(BaseModel): target_words: list[str] -class WordCountActionSpec(BaseModel): - action_type: Literal["wordcount"] = "wordcount" - - -class NoopActionSpec(BaseModel): - action_type: Literal["noop"] = "noop" - - ActionSpecType = Union[ LlmJudgeActionSpec, ContainsWordsActionSpec, - WordCountActionSpec, - NoopActionSpec, ] diff 
--git a/weave/trace_server/interface/base_object_classes/contains_words.py b/weave/trace_server/interface/base_object_classes/contains_words.py new file mode 100644 index 00000000000..9ae71e6f5f9 --- /dev/null +++ b/weave/trace_server/interface/base_object_classes/contains_words.py @@ -0,0 +1,17 @@ +from typing import Any + +from weave.trace_server.interface.base_object_classes.actions import ( + ContainsWordsActionSpec, +) +from weave.trace_server.trace_server_interface import ( + CallSchema, + TraceServerInterface, +) + + +def do_contains_words_action( + config: ContainsWordsActionSpec, + call: CallSchema, + trace_server: TraceServerInterface, +) -> Any: + pass From 0c5be92acc58dfae52592161d4fb432096bb3226 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 17:32:17 -0800 Subject: [PATCH 031/120] Implemented contains words --- .../interface/base_object_classes/contains_words.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/weave/trace_server/interface/base_object_classes/contains_words.py b/weave/trace_server/interface/base_object_classes/contains_words.py index 9ae71e6f5f9..0ea88c8e024 100644 --- a/weave/trace_server/interface/base_object_classes/contains_words.py +++ b/weave/trace_server/interface/base_object_classes/contains_words.py @@ -1,3 +1,4 @@ +import json from typing import Any from weave.trace_server.interface.base_object_classes.actions import ( @@ -14,4 +15,9 @@ def do_contains_words_action( call: CallSchema, trace_server: TraceServerInterface, ) -> Any: - pass + target_words = config.target_words + text = json.dumps(call.outputs) + for word in target_words: + if word in text: + return True + return False From b4af5c6b2f5138a6ee30934aeefff7f9e5c81665 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 17:33:27 -0800 Subject: [PATCH 032/120] Correct location --- .../actions}/contains_words.py | 0 weave/actions_worker/dispatcher.py | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) rename 
weave/{trace_server/interface/base_object_classes => actions_worker/actions}/contains_words.py (100%) diff --git a/weave/trace_server/interface/base_object_classes/contains_words.py b/weave/actions_worker/actions/contains_words.py similarity index 100% rename from weave/trace_server/interface/base_object_classes/contains_words.py rename to weave/actions_worker/actions/contains_words.py diff --git a/weave/actions_worker/dispatcher.py b/weave/actions_worker/dispatcher.py index 9ca39402923..5f80cfb7246 100644 --- a/weave/actions_worker/dispatcher.py +++ b/weave/actions_worker/dispatcher.py @@ -2,15 +2,15 @@ from pydantic import BaseModel +from weave.actions_worker.actions.contains_words import ( + do_contains_words_action, +) from weave.actions_worker.actions.llm_judge import do_llm_judge_action from weave.trace_server.interface.base_object_classes.actions import ( ActionDefinition, ContainsWordsActionSpec, LlmJudgeActionSpec, ) -from weave.trace_server.interface.base_object_classes.contains_words import ( - do_contains_words_action, -) from weave.trace_server.interface.feedback_types import RunnablePayloadSchema from weave.trace_server.refs_internal import ( InternalCallRef, From 712cbb70f66db6c32aa401c6d363e751c4c753c8 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 17:44:46 -0800 Subject: [PATCH 033/120] contains words implemented --- tests/trace/test_actions_lifecycle.py | 4 ++-- weave/actions_worker/actions/contains_words.py | 2 +- weave/actions_worker/dispatcher.py | 11 +++++++---- weave/trace_server/clickhouse_trace_server_batched.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 39b344d4054..01bc65aeeff 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -34,7 +34,7 @@ def test_action_lifecycle_simple(client: WeaveClient): # Part 2: Demonstrate manual feedback (this is not user-facing) 
@weave.op def example_op(input: str) -> str: - return input[::-1] + return input + "!!!" _, call1 = example_op.call("i've been very distracted today") @@ -61,7 +61,7 @@ def example_op(input: str) -> str: # Step 3: test that we can in-place execute one action at a time. - _, call2 = example_op.call("i've been very meditative today") + _, call2 = example_op.call("i've been very mindful today") res = client.server.actions_execute_batch( ActionsExecuteBatchReq.model_validate( diff --git a/weave/actions_worker/actions/contains_words.py b/weave/actions_worker/actions/contains_words.py index 0ea88c8e024..1d539b7d61f 100644 --- a/weave/actions_worker/actions/contains_words.py +++ b/weave/actions_worker/actions/contains_words.py @@ -16,7 +16,7 @@ def do_contains_words_action( trace_server: TraceServerInterface, ) -> Any: target_words = config.target_words - text = json.dumps(call.outputs) + text = json.dumps(call.output) for word in target_words: if word in text: return True diff --git a/weave/actions_worker/dispatcher.py b/weave/actions_worker/dispatcher.py index 5f80cfb7246..d6e92099651 100644 --- a/weave/actions_worker/dispatcher.py +++ b/weave/actions_worker/dispatcher.py @@ -8,6 +8,7 @@ from weave.actions_worker.actions.llm_judge import do_llm_judge_action from weave.trace_server.interface.base_object_classes.actions import ( ActionDefinition, + ActionSpecType, ContainsWordsActionSpec, LlmJudgeActionSpec, ) @@ -32,9 +33,9 @@ ActionFnType = Callable[[ActionDefinition, CallSchema, TraceServerInterface], Any] # TODO: Nail down this typing -dispatch_map: dict[str, ActionFnType] = { - LlmJudgeActionSpec.action_type: do_llm_judge_action, - ContainsWordsActionSpec.action_type: do_contains_words_action, +dispatch_map: dict[type[ActionSpecType], ActionFnType] = { + LlmJudgeActionSpec: do_llm_judge_action, + ContainsWordsActionSpec: do_contains_words_action, } @@ -89,7 +90,7 @@ def dispatch_action( target_call: CallSchema, trace_server: TraceServerInterface, ) -> ActionResult: - 
action_type = action_def.spec.action_type + action_type = type(action_def.spec) action_fn = dispatch_map[action_type] result = action_fn(action_def.spec, target_call, trace_server) feedback_res = publish_results_as_feedback( @@ -118,5 +119,7 @@ def publish_results_as_feedback( feedback_type="wandb.runnable." + action_name, runnable_ref=action_ref, payload=RunnablePayloadSchema(output=result).model_dump(), + # TODO: Make `wb_user_id` optional. + wb_user_id="", ) ) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index d39b6f5d575..4d87730846c 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1421,7 +1421,7 @@ def actions_execute_batch( trace_server=self, ) - raise tsi.ActionsExecuteBatchRes() + return tsi.ActionsExecuteBatchRes() def completions_create( self, req: tsi.CompletionsCreateReq From 212f3b4d817ca2dcb96b27423c682396fb6898dc Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 17:57:17 -0800 Subject: [PATCH 034/120] Implemented lmm judge --- .../actions_worker/actions/contains_words.py | 1 + weave/actions_worker/actions/llm_judge.py | 62 ++++++++++++------- weave/actions_worker/dispatcher.py | 18 +++--- .../interface/base_object_classes/actions.py | 5 +- 4 files changed, 56 insertions(+), 30 deletions(-) diff --git a/weave/actions_worker/actions/contains_words.py b/weave/actions_worker/actions/contains_words.py index 1d539b7d61f..c15813930a7 100644 --- a/weave/actions_worker/actions/contains_words.py +++ b/weave/actions_worker/actions/contains_words.py @@ -11,6 +11,7 @@ def do_contains_words_action( + project_id: str, config: ContainsWordsActionSpec, call: CallSchema, trace_server: TraceServerInterface, diff --git a/weave/actions_worker/actions/llm_judge.py b/weave/actions_worker/actions/llm_judge.py index 34f69fe7415..c8b37efb1cc 100644 --- a/weave/actions_worker/actions/llm_judge.py +++ 
b/weave/actions_worker/actions/llm_judge.py @@ -1,25 +1,24 @@ import json -from functools import partial from typing import Any -from openai import OpenAI - from weave.trace_server.interface.base_object_classes.actions import ( LlmJudgeActionSpec, ) from weave.trace_server.trace_server_interface import ( CallSchema, + CompletionsCreateReq, TraceServerInterface, ) def do_llm_judge_action( - config: LlmJudgeActionSpec, call: CallSchema, trace_server: TraceServerInterface + project_id: str, + config: LlmJudgeActionSpec, + call: CallSchema, + trace_server: TraceServerInterface, ) -> Any: model = config.model system_prompt = config.prompt - if config.response_format is None: - raise ValueError("response_format is required for llm_judge") response_is_not_object = config.response_format["type"] != "object" dummy_key = "response" @@ -42,20 +41,41 @@ def do_llm_judge_action( "output": call.output, } - client = OpenAI() - # Silly hack to get around issue in tests: - create = client.chat.completions.create - if hasattr(create, "resolve_fn"): - create = partial(create.resolve_fn, self=client.chat.completions) - completion = create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": json.dumps(args)}, - ], - response_format=response_format, + completion = trace_server.completions_create( + CompletionsCreateReq( + project_id=project_id, + inputs=dict( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(args)}, + ], + response_format=response_format, + ), + track_llm_call=False, + ) ) - res = json.loads(completion.choices[0].message.content) - if response_is_not_object: - res = res[dummy_key] + + # client = OpenAI() + # # Silly hack to get around issue in tests: + # create = client.chat.completions.create + # if hasattr(create, "resolve_fn"): + # create = partial(create.resolve_fn, self=client.chat.completions) + # completion = create( + # model=model, + # 
messages=[ + # {"role": "system", "content": system_prompt}, + # {"role": "user", "content": json.dumps(args)}, + # ], + # response_format=response_format, + # ) + content = ( + completion.response.get("choices", [{}])[0].get("message", {}).get("content") + ) + if content is None: + res = None + else: + res = json.loads(content) + if response_is_not_object: + res = res[dummy_key] return res diff --git a/weave/actions_worker/dispatcher.py b/weave/actions_worker/dispatcher.py index d6e92099651..ed935678b2c 100644 --- a/weave/actions_worker/dispatcher.py +++ b/weave/actions_worker/dispatcher.py @@ -30,7 +30,7 @@ TraceServerInterface, ) -ActionFnType = Callable[[ActionDefinition, CallSchema, TraceServerInterface], Any] +ActionFnType = Callable[[str, ActionDefinition, CallSchema, TraceServerInterface], Any] # TODO: Nail down this typing dispatch_map: dict[type[ActionSpecType], ActionFnType] = { @@ -48,18 +48,19 @@ def execute_batch( batch_req: ActionsExecuteBatchReq, trace_server: TraceServerInterface, ) -> list[ActionResult]: + project_id = batch_req.project_id # 1. 
Lookup the action definition parsed_ref = parse_internal_uri(batch_req.action_ref) - if parsed_ref.project_id != batch_req.project_id: + if parsed_ref.project_id != project_id: raise ValueError( - f"Action ref {batch_req.action_ref} does not match project_id {batch_req.project_id}" + f"Action ref {batch_req.action_ref} does not match project_id {project_id}" ) if not isinstance(parsed_ref, InternalObjectRef): raise ValueError(f"Action ref {batch_req.action_ref} is not an object ref") action_def_read = trace_server.obj_read( ObjReadReq( - project_id=batch_req.project_id, + project_id=project_id, object_id=parsed_ref.name, digest=parsed_ref.version, ) @@ -69,7 +70,7 @@ def execute_batch( # Lookup the calls calls_query = trace_server.calls_query( CallsQueryReq( - project_id=batch_req.project_id, + project_id=project_id, filter=CallsFilter(call_ids=batch_req.call_ids), ) ) @@ -79,12 +80,15 @@ def execute_batch( # TODO: Some actions may be able to be batched together results = [] for call in calls: - result = dispatch_action(batch_req.action_ref, action_def, call, trace_server) + result = dispatch_action( + project_id, batch_req.action_ref, action_def, call, trace_server + ) results.append(result) return results def dispatch_action( + project_id: str, action_ref: str, action_def: ActionDefinition, target_call: CallSchema, @@ -92,7 +96,7 @@ def dispatch_action( ) -> ActionResult: action_type = type(action_def.spec) action_fn = dispatch_map[action_type] - result = action_fn(action_def.spec, target_call, trace_server) + result = action_fn(project_id, action_def.spec, target_call, trace_server) feedback_res = publish_results_as_feedback( target_call, action_ref, result, trace_server ) diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index ac0a19ba9d4..fd78a654492 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ 
b/weave/trace_server/interface/base_object_classes/actions.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from pydantic import BaseModel @@ -10,7 +10,8 @@ class LlmJudgeActionSpec(BaseModel): # TODO: Remove this restriction model: Literal["gpt-4o", "gpt-4o-mini"] prompt: str - response_format: Optional[dict[str, Any]] + # Expected to be valid JSON Schema + response_format: dict[str, Any] class ContainsWordsActionSpec(BaseModel): From d6655e7e775335de048cfdd8b29aac18692e63cc Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 18:14:51 -0800 Subject: [PATCH 035/120] Finished basic implementation of LLM judge --- tests/trace/test_actions_lifecycle.py | 37 ++++++++++++++++++--------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 01bc65aeeff..28147b792c2 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -1,6 +1,11 @@ +import os + import pytest import weave +from tests.integrations.litellm.client_completions_create_test import ( + secret_fetcher_context, +) from tests.trace.util import client_is_sqlite from weave.trace.weave_client import WeaveClient from weave.trace_server.interface.base_object_classes.actions import ( @@ -12,7 +17,14 @@ ) -def test_action_lifecycle_simple(client: WeaveClient): +class DummySecretFetcher: + def fetch(self, secret_name: str) -> dict: + return { + "secrets": {secret_name: os.environ.get(secret_name, "DUMMY_SECRET_VALUE")} + } + + +def test_action_lifecycle_word_count(client: WeaveClient): if client_is_sqlite(client): return pytest.skip("skipping for sqlite") @@ -104,20 +116,21 @@ def test_action_lifecycle_llm_judge(client: WeaveClient): @weave.op def example_op(input: str) -> str: - return input[::-1] + return input + "." # Step 2: test that we can in-place execute one action at a time. 
- _, call = example_op.call("i've been very meditative today") - - res = client.server.actions_execute_batch( - ActionsExecuteBatchReq.model_validate( - { - "project_id": client._project_id(), - "action_ref": action_ref_uri, - "call_ids": [call.id], - } + _, call = example_op.call("i've been very meditative and mindful today") + + with secret_fetcher_context(DummySecretFetcher()): + res = client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) ) - ) feedbacks = list(call.feedback) assert len(feedbacks) == 1 From 7649ef4d748fdd23edc25a33245439a737d711e2 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 18:22:46 -0800 Subject: [PATCH 036/120] All basic functionality working - still some todos, but the base is there --- tests/trace/test_actions_lifecycle.py | 54 +++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 28147b792c2..f7762833aee 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -1,6 +1,8 @@ import os +from unittest.mock import patch import pytest +from litellm.assistants.main import ModelResponse import weave from tests.integrations.litellm.client_completions_create_test import ( @@ -93,6 +95,40 @@ def example_op(input: str) -> str: assert feedback.payload == {"output": True} +mock_response = { + "id": "chatcmpl-AQPvs3DE4NQqLxorvaTPixpqq9nTD", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": '{"response":true}', + "role": "assistant", + "tool_calls": None, + "function_call": None, + }, + } + ], + "created": 1730859576, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_0ba0d124f1", + "usage": { + "completion_tokens": 5, + "prompt_tokens": 74, + "total_tokens": 79, + 
"completion_tokens_details": { + "audio_tokens": 0, + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + "service_tier": None, +} + + def test_action_lifecycle_llm_judge(client: WeaveClient): if client_is_sqlite(client): return pytest.skip("skipping for sqlite") @@ -122,15 +158,17 @@ def example_op(input: str) -> str: _, call = example_op.call("i've been very meditative and mindful today") with secret_fetcher_context(DummySecretFetcher()): - res = client.server.actions_execute_batch( - ActionsExecuteBatchReq.model_validate( - { - "project_id": client._project_id(), - "action_ref": action_ref_uri, - "call_ids": [call.id], - } + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse.model_validate(mock_response) + client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) ) - ) feedbacks = list(call.feedback) assert len(feedbacks) == 1 From 406333a4fe8932eddc6166cd3f82219b6822f7bb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 18:46:54 -0800 Subject: [PATCH 037/120] More structured tests --- tests/trace/test_actions_lifecycle.py | 102 +++++++++++++++++++++- weave/actions_worker/actions/llm_judge.py | 13 --- 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index f7762833aee..b26d2cb86f0 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -95,7 +95,7 @@ def example_op(input: str) -> str: assert feedback.payload == {"output": True} -mock_response = { +primitive_mock_response = { "id": "chatcmpl-AQPvs3DE4NQqLxorvaTPixpqq9nTD", "choices": [ { @@ -129,7 +129,7 @@ def example_op(input: str) -> str: } -def 
test_action_lifecycle_llm_judge(client: WeaveClient): +def test_action_lifecycle_llm_judge_primitive(client: WeaveClient): if client_is_sqlite(client): return pytest.skip("skipping for sqlite") @@ -159,7 +159,9 @@ def example_op(input: str) -> str: with secret_fetcher_context(DummySecretFetcher()): with patch("litellm.completion") as mock_completion: - mock_completion.return_value = ModelResponse.model_validate(mock_response) + mock_completion.return_value = ModelResponse.model_validate( + primitive_mock_response + ) client.server.actions_execute_batch( ActionsExecuteBatchReq.model_validate( { @@ -176,3 +178,97 @@ def example_op(input: str) -> str: assert feedback.feedback_type == "wandb.runnable." + action_name assert feedback.runnable_ref == action_ref_uri assert feedback.payload == {"output": True} + + +structured_mock_response = { + "id": "chatcmpl-AQQKJWQDxSvU2Ya9ool2vgcJrFuON", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": '{"is_mindful":true,"reason":"The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state."}', + "role": "assistant", + "tool_calls": None, + "function_call": None, + }, + } + ], + "created": 1730861091, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_0ba0d124f1", + "usage": { + "completion_tokens": 32, + "prompt_tokens": 84, + "total_tokens": 116, + "completion_tokens_details": { + "audio_tokens": 0, + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + "service_tier": None, +} + + +def test_action_lifecycle_llm_judge_structured(client: WeaveClient): + if client_is_sqlite(client): + return pytest.skip("skipping for sqlite") + + action_name = "response_is_mindful" + + published_ref = weave.publish( + ActionDefinition( + name=action_name, + spec={ + 
"action_type": "llm_judge", + "model": "gpt-4o-mini", + "prompt": "Is the response mindful?", + "response_format": { + "type": "object", + "properties": { + "is_mindful": {"type": "boolean"}, + "reason": {"type": "string"}, + }, + }, + }, + ) + ) + + # Construct the URI + action_ref_uri = published_ref.uri() + + @weave.op + def example_op(input: str) -> str: + return input + "." + + # Step 2: test that we can in-place execute one action at a time. + _, call = example_op.call("i've been very meditative and mindful today") + + with secret_fetcher_context(DummySecretFetcher()): + # with patch("litellm.completion") as mock_completion: + # mock_completion.return_value = ModelResponse.model_validate(mock_response) + client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) + ) + + feedbacks = list(call.feedback) + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." 
+ action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == { + "output": { + "is_mindful": True, + "reason": "The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state.", + } + } diff --git a/weave/actions_worker/actions/llm_judge.py b/weave/actions_worker/actions/llm_judge.py index c8b37efb1cc..8d7507ef766 100644 --- a/weave/actions_worker/actions/llm_judge.py +++ b/weave/actions_worker/actions/llm_judge.py @@ -56,19 +56,6 @@ def do_llm_judge_action( ) ) - # client = OpenAI() - # # Silly hack to get around issue in tests: - # create = client.chat.completions.create - # if hasattr(create, "resolve_fn"): - # create = partial(create.resolve_fn, self=client.chat.completions) - # completion = create( - # model=model, - # messages=[ - # {"role": "system", "content": system_prompt}, - # {"role": "user", "content": json.dumps(args)}, - # ], - # response_format=response_format, - # ) content = ( completion.response.get("choices", [{}])[0].get("message", {}).get("content") ) From dd08ab795e4078ea2a4e285ba7fddf13e78c22c3 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 18:53:59 -0800 Subject: [PATCH 038/120] change name to response_schema --- tests/trace/test_actions_lifecycle.py | 4 ++-- weave/actions_worker/actions/llm_judge.py | 6 +++--- weave/trace_server/interface/base_object_classes/actions.py | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index b26d2cb86f0..65562fd4ca4 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -142,7 +142,7 @@ def test_action_lifecycle_llm_judge_primitive(client: WeaveClient): "action_type": "llm_judge", "model": "gpt-4o-mini", "prompt": "Is the response mindful?", - "response_format": {"type": "boolean"}, + "response_schema": {"type": "boolean"}, }, ) ) @@ -227,7 +227,7 @@ 
def test_action_lifecycle_llm_judge_structured(client: WeaveClient): "action_type": "llm_judge", "model": "gpt-4o-mini", "prompt": "Is the response mindful?", - "response_format": { + "response_schema": { "type": "object", "properties": { "is_mindful": {"type": "boolean"}, diff --git a/weave/actions_worker/actions/llm_judge.py b/weave/actions_worker/actions/llm_judge.py index 8d7507ef766..f22cd1a1731 100644 --- a/weave/actions_worker/actions/llm_judge.py +++ b/weave/actions_worker/actions/llm_judge.py @@ -20,16 +20,16 @@ def do_llm_judge_action( model = config.model system_prompt = config.prompt - response_is_not_object = config.response_format["type"] != "object" + response_is_not_object = config.response_schema["type"] != "object" dummy_key = "response" if response_is_not_object: schema = { "type": "object", - "properties": {dummy_key: config.response_format}, + "properties": {dummy_key: config.response_schema}, "additionalProperties": False, } else: - schema = config.response_format + schema = config.response_schema response_format = { "type": "json_schema", diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index fd78a654492..8ef94402945 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -7,11 +7,12 @@ class LlmJudgeActionSpec(BaseModel): action_type: Literal["llm_judge"] = "llm_judge" - # TODO: Remove this restriction + # TODO: Remove this restriction (probably after initial release. 
We need to cross + # reference which LiteLLM models support structured outputs) model: Literal["gpt-4o", "gpt-4o-mini"] prompt: str # Expected to be valid JSON Schema - response_format: dict[str, Any] + response_schema: dict[str, Any] class ContainsWordsActionSpec(BaseModel): From 3bb0db573f2fbc054a091707d84754ac5d398f55 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 18:55:54 -0800 Subject: [PATCH 039/120] change name to response_schema --- weave/trace_server/clickhouse_trace_server_batched.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index a21feb77d48..17d54f3fca2 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1413,8 +1413,10 @@ def actions_execute_batch( if len(req.call_ids) == 0: return tsi.ActionsExecuteBatchRes() if len(req.call_ids) > 1: + # This is temporary until we setup our batching infrastructure raise NotImplementedError("Batching actions is not yet supported") + # For now, we just execute in-process if it is a single action execute_batch( batch_req=req, trace_server=self, From 849368924402b720e2d76977f12d88109fe6b670 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 19:18:20 -0800 Subject: [PATCH 040/120] Fixed structured test --- tests/trace/test_actions_lifecycle.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 65562fd4ca4..1246046363b 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -249,17 +249,19 @@ def example_op(input: str) -> str: _, call = example_op.call("i've been very meditative and mindful today") with secret_fetcher_context(DummySecretFetcher()): - # with patch("litellm.completion") as mock_completion: - # 
mock_completion.return_value = ModelResponse.model_validate(mock_response) - client.server.actions_execute_batch( - ActionsExecuteBatchReq.model_validate( - { - "project_id": client._project_id(), - "action_ref": action_ref_uri, - "call_ids": [call.id], - } + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse.model_validate( + structured_mock_response + ) + client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) ) - ) feedbacks = list(call.feedback) assert len(feedbacks) == 1 From 73514c7d0c5ecc62b250c197b846346e8c287462 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 19:20:27 -0800 Subject: [PATCH 041/120] Added wb user id to the batch executor --- weave/actions_worker/dispatcher.py | 18 ++++++++++++------ ...xternal_to_internal_trace_server_adapter.py | 4 ++++ weave/trace_server/trace_server_interface.py | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/weave/actions_worker/dispatcher.py b/weave/actions_worker/dispatcher.py index ed935678b2c..de9908f405d 100644 --- a/weave/actions_worker/dispatcher.py +++ b/weave/actions_worker/dispatcher.py @@ -32,7 +32,7 @@ ActionFnType = Callable[[str, ActionDefinition, CallSchema, TraceServerInterface], Any] -# TODO: Nail down this typing + dispatch_map: dict[type[ActionSpecType], ActionFnType] = { LlmJudgeActionSpec: do_llm_judge_action, ContainsWordsActionSpec: do_contains_words_action, @@ -49,6 +49,11 @@ def execute_batch( trace_server: TraceServerInterface, ) -> list[ActionResult]: project_id = batch_req.project_id + wb_user_id = batch_req.wb_user_id + if wb_user_id is None: + # We should probably relax this for online evals + raise ValueError("wb_user_id cannot be None") + # 1. 
Lookup the action definition parsed_ref = parse_internal_uri(batch_req.action_ref) if parsed_ref.project_id != project_id: @@ -77,11 +82,11 @@ def execute_batch( calls = calls_query.calls # 2. Dispatch the action to each call - # TODO: Some actions may be able to be batched together + # FUTURE: Some actions may be able to be batched together results = [] for call in calls: result = dispatch_action( - project_id, batch_req.action_ref, action_def, call, trace_server + project_id, batch_req.action_ref, action_def, call, wb_user_id, trace_server ) results.append(result) return results @@ -92,13 +97,14 @@ def dispatch_action( action_ref: str, action_def: ActionDefinition, target_call: CallSchema, + wb_user_id: str, trace_server: TraceServerInterface, ) -> ActionResult: action_type = type(action_def.spec) action_fn = dispatch_map[action_type] result = action_fn(project_id, action_def.spec, target_call, trace_server) feedback_res = publish_results_as_feedback( - target_call, action_ref, result, trace_server + target_call, action_ref, result, wb_user_id, trace_server ) return ActionResult(result=result, feedback_res=feedback_res) @@ -107,6 +113,7 @@ def publish_results_as_feedback( target_call: CallSchema, action_ref: str, result: Any, + wb_user_id: str, trace_server: TraceServerInterface, ) -> FeedbackCreateRes: project_id = target_call.project_id @@ -123,7 +130,6 @@ def publish_results_as_feedback( feedback_type="wandb.runnable." + action_name, runnable_ref=action_ref, payload=RunnablePayloadSchema(output=result).model_dump(), - # TODO: Make `wb_user_id` optional. 
- wb_user_id="", + wb_user_id=wb_user_id, ) ) diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index bc21f607180..5878bdad0ee 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -351,6 +351,10 @@ def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) res = self._ref_apply(self._internal_trace_server.actions_execute_batch, req) return res diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 235c51cc8b6..b3fb3b58a7f 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -852,6 +852,7 @@ class ActionsExecuteBatchReq(BaseModel): project_id: str action_ref: str call_ids: list[str] + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) class ActionsExecuteBatchRes(BaseModel): From 07ee02a145e19df93dbe704dfec3f6944a414376 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 19:32:02 -0800 Subject: [PATCH 042/120] Fixed the tests correctly --- tests/conftest.py | 6 + .../test_actions_lifecycle_llm_judge.py | 211 ++++++++++++++++++ tests/trace/test_actions_lifecycle.py | 161 ------------- 3 files changed, 217 insertions(+), 161 deletions(-) create mode 100644 tests/integrations/litellm/test_actions_lifecycle_llm_judge.py diff --git a/tests/conftest.py b/tests/conftest.py index 6d237c00296..b28187a3833 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -324,6 +324,12 @@ def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: 
req.wb_user_id = self._user_id return super().cost_create(req) + def actions_execute_batch( + self, req: tsi.ActionsExecuteBatchReq + ) -> tsi.ActionsExecuteBatchRes: + req.wb_user_id = self._user_id + return super().actions_execute_batch(req) + # https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable def get_test_name(): diff --git a/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py b/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py new file mode 100644 index 00000000000..59f8c17e260 --- /dev/null +++ b/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py @@ -0,0 +1,211 @@ +""" +The only reason this test is here is to piggy back off the existing litellm test environment. +TODO: merge this back into `test_actions_lifecycle.py` +""" + +import os +from unittest.mock import patch + +import pytest +from litellm.assistants.main import ModelResponse + +import weave +from tests.integrations.litellm.client_completions_create_test import ( + secret_fetcher_context, +) +from tests.trace.util import client_is_sqlite +from weave.trace.weave_client import WeaveClient +from weave.trace_server.interface.base_object_classes.actions import ( + ActionDefinition, +) +from weave.trace_server.trace_server_interface import ( + ActionsExecuteBatchReq, +) + + +class DummySecretFetcher: + def fetch(self, secret_name: str) -> dict: + return { + "secrets": {secret_name: os.environ.get(secret_name, "DUMMY_SECRET_VALUE")} + } + + +primitive_mock_response = { + "id": "chatcmpl-AQPvs3DE4NQqLxorvaTPixpqq9nTD", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": '{"response":true}', + "role": "assistant", + "tool_calls": None, + "function_call": None, + }, + } + ], + "created": 1730859576, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_0ba0d124f1", + "usage": { + "completion_tokens": 5, + "prompt_tokens": 74, + "total_tokens": 79, + 
"completion_tokens_details": { + "audio_tokens": 0, + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + "service_tier": None, +} + + +def test_action_lifecycle_llm_judge_primitive(client: WeaveClient): + if client_is_sqlite(client): + return pytest.skip("skipping for sqlite") + + action_name = "response_is_mindful" + + published_ref = weave.publish( + ActionDefinition( + name=action_name, + spec={ + "action_type": "llm_judge", + "model": "gpt-4o-mini", + "prompt": "Is the response mindful?", + "response_schema": {"type": "boolean"}, + }, + ) + ) + + # Construct the URI + action_ref_uri = published_ref.uri() + + @weave.op + def example_op(input: str) -> str: + return input + "." + + # Step 2: test that we can in-place execute one action at a time. + _, call = example_op.call("i've been very meditative and mindful today") + + with secret_fetcher_context(DummySecretFetcher()): + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse.model_validate( + primitive_mock_response + ) + client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) + ) + + feedbacks = list(call.feedback) + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." 
+ action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == {"output": True} + + +structured_mock_response = { + "id": "chatcmpl-AQQKJWQDxSvU2Ya9ool2vgcJrFuON", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": '{"is_mindful":true,"reason":"The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state."}', + "role": "assistant", + "tool_calls": None, + "function_call": None, + }, + } + ], + "created": 1730861091, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_0ba0d124f1", + "usage": { + "completion_tokens": 32, + "prompt_tokens": 84, + "total_tokens": 116, + "completion_tokens_details": { + "audio_tokens": 0, + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + "service_tier": None, +} + + +def test_action_lifecycle_llm_judge_structured(client: WeaveClient): + if client_is_sqlite(client): + return pytest.skip("skipping for sqlite") + + action_name = "response_is_mindful" + + published_ref = weave.publish( + ActionDefinition( + name=action_name, + spec={ + "action_type": "llm_judge", + "model": "gpt-4o-mini", + "prompt": "Is the response mindful?", + "response_schema": { + "type": "object", + "properties": { + "is_mindful": {"type": "boolean"}, + "reason": {"type": "string"}, + }, + }, + }, + ) + ) + + # Construct the URI + action_ref_uri = published_ref.uri() + + @weave.op + def example_op(input: str) -> str: + return input + "." + + # Step 2: test that we can in-place execute one action at a time. 
+ _, call = example_op.call("i've been very meditative and mindful today") + + with secret_fetcher_context(DummySecretFetcher()): + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse.model_validate( + structured_mock_response + ) + client.server.actions_execute_batch( + ActionsExecuteBatchReq.model_validate( + { + "project_id": client._project_id(), + "action_ref": action_ref_uri, + "call_ids": [call.id], + } + ) + ) + + feedbacks = list(call.feedback) + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback.feedback_type == "wandb.runnable." + action_name + assert feedback.runnable_ref == action_ref_uri + assert feedback.payload == { + "output": { + "is_mindful": True, + "reason": "The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state.", + } + } diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 1246046363b..7ccae08e819 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -1,13 +1,6 @@ -import os -from unittest.mock import patch - import pytest -from litellm.assistants.main import ModelResponse import weave -from tests.integrations.litellm.client_completions_create_test import ( - secret_fetcher_context, -) from tests.trace.util import client_is_sqlite from weave.trace.weave_client import WeaveClient from weave.trace_server.interface.base_object_classes.actions import ( @@ -19,13 +12,6 @@ ) -class DummySecretFetcher: - def fetch(self, secret_name: str) -> dict: - return { - "secrets": {secret_name: os.environ.get(secret_name, "DUMMY_SECRET_VALUE")} - } - - def test_action_lifecycle_word_count(client: WeaveClient): if client_is_sqlite(client): return pytest.skip("skipping for sqlite") @@ -127,150 +113,3 @@ def example_op(input: str) -> str: }, "service_tier": None, } - - -def test_action_lifecycle_llm_judge_primitive(client: 
WeaveClient): - if client_is_sqlite(client): - return pytest.skip("skipping for sqlite") - - action_name = "response_is_mindful" - - published_ref = weave.publish( - ActionDefinition( - name=action_name, - spec={ - "action_type": "llm_judge", - "model": "gpt-4o-mini", - "prompt": "Is the response mindful?", - "response_schema": {"type": "boolean"}, - }, - ) - ) - - # Construct the URI - action_ref_uri = published_ref.uri() - - @weave.op - def example_op(input: str) -> str: - return input + "." - - # Step 2: test that we can in-place execute one action at a time. - _, call = example_op.call("i've been very meditative and mindful today") - - with secret_fetcher_context(DummySecretFetcher()): - with patch("litellm.completion") as mock_completion: - mock_completion.return_value = ModelResponse.model_validate( - primitive_mock_response - ) - client.server.actions_execute_batch( - ActionsExecuteBatchReq.model_validate( - { - "project_id": client._project_id(), - "action_ref": action_ref_uri, - "call_ids": [call.id], - } - ) - ) - - feedbacks = list(call.feedback) - assert len(feedbacks) == 1 - feedback = feedbacks[0] - assert feedback.feedback_type == "wandb.runnable." 
+ action_name - assert feedback.runnable_ref == action_ref_uri - assert feedback.payload == {"output": True} - - -structured_mock_response = { - "id": "chatcmpl-AQQKJWQDxSvU2Ya9ool2vgcJrFuON", - "choices": [ - { - "finish_reason": "stop", - "index": 0, - "message": { - "content": '{"is_mindful":true,"reason":"The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state."}', - "role": "assistant", - "tool_calls": None, - "function_call": None, - }, - } - ], - "created": 1730861091, - "model": "gpt-4o-mini-2024-07-18", - "object": "chat.completion", - "system_fingerprint": "fp_0ba0d124f1", - "usage": { - "completion_tokens": 32, - "prompt_tokens": 84, - "total_tokens": 116, - "completion_tokens_details": { - "audio_tokens": 0, - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0, - }, - "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, - }, - "service_tier": None, -} - - -def test_action_lifecycle_llm_judge_structured(client: WeaveClient): - if client_is_sqlite(client): - return pytest.skip("skipping for sqlite") - - action_name = "response_is_mindful" - - published_ref = weave.publish( - ActionDefinition( - name=action_name, - spec={ - "action_type": "llm_judge", - "model": "gpt-4o-mini", - "prompt": "Is the response mindful?", - "response_schema": { - "type": "object", - "properties": { - "is_mindful": {"type": "boolean"}, - "reason": {"type": "string"}, - }, - }, - }, - ) - ) - - # Construct the URI - action_ref_uri = published_ref.uri() - - @weave.op - def example_op(input: str) -> str: - return input + "." - - # Step 2: test that we can in-place execute one action at a time. 
- _, call = example_op.call("i've been very meditative and mindful today") - - with secret_fetcher_context(DummySecretFetcher()): - with patch("litellm.completion") as mock_completion: - mock_completion.return_value = ModelResponse.model_validate( - structured_mock_response - ) - client.server.actions_execute_batch( - ActionsExecuteBatchReq.model_validate( - { - "project_id": client._project_id(), - "action_ref": action_ref_uri, - "call_ids": [call.id], - } - ) - ) - - feedbacks = list(call.feedback) - assert len(feedbacks) == 1 - feedback = feedbacks[0] - assert feedback.feedback_type == "wandb.runnable." + action_name - assert feedback.runnable_ref == action_ref_uri - assert feedback.payload == { - "output": { - "is_mindful": True, - "reason": "The response reflects a state of being that embodies mindfulness and meditation, acknowledging a positive mental and emotional state.", - } - } From f71a0e1e7806c8a3f875638fd688f57cbfaa6bc1 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 19:39:17 -0800 Subject: [PATCH 043/120] Init Changes --- .../components/FancyPage/useProjectSidebar.ts | 13 + weave-js/src/components/Form/AutoComplete.tsx | 4 - weave-js/src/components/Form/Select.tsx | 16 +- .../Home/Browse2/SmallRef.tsx | 7 +- .../PagePanelComponents/Home/Browse3.tsx | 17 +- .../Home/Browse3/DynamicConfigForm.tsx | 826 ++++++++++++++++++ .../Home/Browse3/ReusableDrawer.tsx | 64 ++ .../Browse3/collections/actionCollection.ts | 51 ++ .../Browse3/collections/collectionRegistry.ts | 9 + .../collections/getCollectionObjects.tsx | 118 +++ .../Home/Browse3/context.tsx | 15 + .../pages/CallPage/CallActionsViewer.tsx | 277 ++++++ .../Home/Browse3/pages/CallPage/CallPage.tsx | 16 + .../Browse3/pages/CallPage/ObjectViewer.tsx | 8 +- .../Home/Browse3/pages/ObjectVersionPage.tsx | 2 + .../OpVersionPage/OpOnlineScorersTab.tsx | 289 ++++++ .../pages/OpVersionPage/OpVersionPage.tsx | 163 ++++ .../NewBuiltInActionScorerModal.tsx | 155 ++++ 
.../Browse3/pages/ScorersPage/ScorersPage.tsx | 219 +++++ .../pages/ScorersPage/actionTemplates.tsx | 58 ++ .../Browse3/pages/common/SimplePageLayout.tsx | 33 +- .../pages/common/TypeVersionCategoryChip.tsx | 2 + .../pages/wfReactInterface/constants.ts | 2 + .../traceServerClientTypes.ts | 8 + .../traceServerDirectClient.ts | 25 +- .../Home/Browse3/windowFlags.tsx | 103 +++ .../Panel2/PanelTable/tableState.ts | 23 +- 27 files changed, 2461 insertions(+), 62 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/ReusableDrawer.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index c9d0b928997..c33b6bfe90a 100644 --- 
a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -2,6 +2,10 @@ import {IconNames} from '@wandb/weave/components/Icon'; import _ from 'lodash'; import {useMemo} from 'react'; +import { + ENABLE_ONLINE_EVAL_UI, + getFeatureFlag, +} from '../PagePanelComponents/Home/Browse3/windowFlags'; import {FancyPageSidebarItem} from './FancyPageSidebar'; export const useProjectSidebar = ( @@ -31,6 +35,7 @@ export const useProjectSidebar = ( const isNoSidebarItems = !showModelsSidebarItems && !showWeaveSidebarItems; const isBothSidebarItems = showModelsSidebarItems && showWeaveSidebarItems; const isShowAll = isNoSidebarItems || isBothSidebarItems; + const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); return useMemo(() => { const allItems = isLoading @@ -183,6 +188,13 @@ export const useProjectSidebar = ( isShown: showWeaveSidebarItems || isShowAll, iconName: IconNames.Table, }, + { + type: 'button' as const, + name: 'Scorers', + slug: 'weave/scorers', + isShown: enableOnlineEvalUI && (showWeaveSidebarItems || isShowAll), + iconName: IconNames.TypeNumberAlt, + }, { type: 'divider' as const, key: 'dividerWithinWeave-3', @@ -243,5 +255,6 @@ export const useProjectSidebar = ( viewingRestricted, isModelsOnly, showWeaveSidebarItems, + enableOnlineEvalUI, ]); }; diff --git a/weave-js/src/components/Form/AutoComplete.tsx b/weave-js/src/components/Form/AutoComplete.tsx index 4fba92aec38..6455790fb3c 100644 --- a/weave-js/src/components/Form/AutoComplete.tsx +++ b/weave-js/src/components/Form/AutoComplete.tsx @@ -63,10 +63,6 @@ const getStyles = (props: AdditionalProps) => { minHeight: `${HEIGHTS[size]} !important`, }, }, - '& .MuiAutocomplete-popupIndicator': { - borderRadius: '4px', - padding: '4px', - }, '&.MuiAutocomplete-hasPopupIcon .MuiOutlinedInput-root, &.MuiAutocomplete-hasClearIcon .MuiOutlinedInput-root': { paddingRight: props.hasInputValue ? 
'28px' : '0px', // Apply padding only if input exists diff --git a/weave-js/src/components/Form/Select.tsx b/weave-js/src/components/Form/Select.tsx index 2163b6af180..30da9ea00fe 100644 --- a/weave-js/src/components/Form/Select.tsx +++ b/weave-js/src/components/Form/Select.tsx @@ -16,8 +16,7 @@ import { MOON_800, RED_550, TEAL_300, - TEAL_350, - TEAL_400, + TEAL_500, TEAL_600, } from '@wandb/weave/common/css/globals.styles'; import {Icon} from '@wandb/weave/components/Icon'; @@ -205,8 +204,10 @@ const getStyles = < }, control: (baseStyles, state) => { const colorBorderDefault = MOON_250; - const colorBorderHover = TEAL_350; - const colorBorderOpen = errorState ? hexToRGB(RED_550, 0.64) : TEAL_400; + const colorBorderHover = hexToRGB(TEAL_500, 0.4); + const colorBorderOpen = errorState + ? hexToRGB(RED_550, 0.64) + : hexToRGB(TEAL_500, 0.64); const height = HEIGHTS[size]; const minHeight = MIN_HEIGHTS[size] ?? height; const lineHeight = LINE_HEIGHTS[size]; @@ -225,10 +226,9 @@ const getStyles = < ? `0 0 0 2px ${colorBorderOpen}` : `inset 0 0 0 1px ${colorBorderDefault}`, '&:hover': { - boxShadow: - state.menuIsOpen || state.isFocused - ? `0 0 0 2px ${colorBorderOpen}` - : `0 0 0 2px ${colorBorderHover}`, + boxShadow: state.menuIsOpen + ? 
`0 0 0 2px ${colorBorderOpen}` + : `0 0 0 2px ${colorBorderHover}`, }, }; }, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx index e6f4c0955a8..832a9c6b39e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx @@ -80,7 +80,8 @@ export const SmallRef: FC<{ objRef: ObjectRef; wfTable?: WFDBTableType; iconOnly?: boolean; -}> = ({objRef, wfTable, iconOnly = false}) => { + nameOnly?: boolean; +}> = ({objRef, wfTable, iconOnly = false, nameOnly = false}) => { const { useObjectVersion, useOpVersion, @@ -140,7 +141,9 @@ export const SmallRef: FC<{ // TODO: Why is this necessary? The type is coming back as `objRef` rootType = {type: 'OpDef'}; } - const {label} = objectRefDisplayName(objRef, versionIndex); + const {label} = nameOnly + ? {label: objRef.artifactName} + : objectRefDisplayName(objRef, versionIndex); const rootTypeName = getTypeName(rootType); let icon: IconName = IconNames.CubeContainer; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 2c458814245..f73b62447ce 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -92,8 +92,9 @@ import { } from './Browse3/pages/ObjectVersionsPage'; import {OpPage} from './Browse3/pages/OpPage'; import {OpsPage} from './Browse3/pages/OpsPage'; -import {OpVersionPage} from './Browse3/pages/OpVersionPage'; +import {OpVersionPage} from './Browse3/pages/OpVersionPage/OpVersionPage'; import {OpVersionsPage} from './Browse3/pages/OpVersionsPage'; +import {ScorersPage} from './Browse3/pages/ScorersPage/ScorersPage'; import {TablePage} from './Browse3/pages/TablePage'; import {TablesPage} from './Browse3/pages/TablesPage'; import 
{useURLSearchParamsDict} from './Browse3/pages/util'; @@ -102,6 +103,7 @@ import { WFDataModelAutoProvider, } from './Browse3/pages/wfReactInterface/context'; import {useHasTraceServerClientContext} from './Browse3/pages/wfReactInterface/traceServerClientContext'; +import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from './Browse3/windowFlags'; import {useDrawerResize} from './useDrawerResize'; LicenseInfo.setLicenseKey( @@ -156,6 +158,7 @@ const tabOptions = [ 'leaderboards', 'boards', 'tables', + 'scorers', ]; const tabs = tabOptions.join('|'); const browse3Paths = (projectRoot: string) => [ @@ -497,6 +500,9 @@ const Browse3ProjectRoot: FC<{ + + + { ); }; +const ScorersPageBinding = () => { + const {entity, project} = useParamsDecoded(); + const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); + if (!enableOnlineEvalUI) { + return null; + } + return ; +}; + const LeaderboardPageBinding = () => { const params = useParamsDecoded(); const {entity, project, itemName: leaderboardName} = params; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx new file mode 100644 index 00000000000..e5561c77f9c --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx @@ -0,0 +1,826 @@ +import { + Box, + Button, + Checkbox, + FormControl, + FormControlLabel, + IconButton, + InputLabel, + MenuItem, + Select, + TextField, + Typography, +} from '@material-ui/core'; +import {Delete} from '@mui/icons-material'; +import React, {useEffect, useMemo, useState} from 'react'; +import {z} from 'zod'; + +import {parseRefMaybe} from '../Browse2/SmallRef'; + +interface DynamicConfigFormProps { + configSchema: z.ZodType; + config: Record; + setConfig: (config: Record) => void; + path?: string[]; + onValidChange?: (isValid: boolean) => void; +} + +const isZodType = ( + schema: z.ZodTypeAny, + predicate: (s: z.ZodTypeAny) => boolean 
+): boolean => { + if (predicate(schema)) { + return true; + } + if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) { + return isZodType(unwrapSchema(schema), predicate); + } + if (schema instanceof z.ZodDiscriminatedUnion) { + return true; + } + return false; +}; + +const unwrapSchema = (schema: z.ZodTypeAny): z.ZodTypeAny => { + if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) { + return unwrapSchema(schema._def.innerType); + } + if (schema instanceof z.ZodDiscriminatedUnion) { + return schema; + } + return schema; +}; + +const DiscriminatedUnionField: React.FC<{ + keyName: string; + fieldSchema: z.ZodDiscriminatedUnion< + string, + Array> + >; + targetPath: string[]; + value: any; + config: Record; + setConfig: (config: Record) => void; +}> = ({keyName, fieldSchema, targetPath, value, config, setConfig}) => { + const discriminator = fieldSchema._def.discriminator; + const options = fieldSchema._def.options; + + const currentType = + value?.[discriminator] || options[0]._def.shape()[discriminator]._def.value; + + const handleTypeChange = (newType: string) => { + const selectedOption = options.find( + option => option._def.shape()[discriminator]._def.value === newType + ); + if (selectedOption) { + const newValue = {[discriminator]: newType}; + Object.keys(selectedOption.shape).forEach(key => { + if (key !== discriminator) { + newValue[key] = + selectedOption.shape[key] instanceof z.ZodDefault + ? 
selectedOption.shape[key]._def.defaultValue() + : undefined; + } + }); + updateConfig(targetPath, newValue, config, setConfig); + } + }; + + const selectedSchema = options.find( + option => option._def.shape()[discriminator]._def.value === currentType + )!; + + // Create a new schema without the discriminator field + const filteredSchema = z.object( + Object.entries(selectedSchema.shape).reduce((acc, [key, innerValue]) => { + if (key !== discriminator) { + acc[key] = innerValue as z.ZodTypeAny; + } + return acc; + }, {} as Record) + ); + + return ( + + {keyName} + + + { + const updatedConfig = {...newConfig, [discriminator]: currentType}; + updateConfig(targetPath, updatedConfig, config, setConfig); + }} + path={[]} + /> + + + ); +}; + +const NestedForm: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + config: Record; + setConfig: (config: Record) => void; + path: string[]; +}> = ({keyName, fieldSchema, config, setConfig, path}) => { + const currentPath = [...path, keyName]; + const currentValue = getNestedValue(config, currentPath); + + const unwrappedSchema = unwrapSchema(fieldSchema); + + if (unwrappedSchema instanceof z.ZodDiscriminatedUnion) { + return ( + + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodObject)) { + return ( + + {keyName} + + } + config={config} + setConfig={setConfig} + path={currentPath} + /> + + + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodArray)) { + return ( + } + targetPath={currentPath} + value={currentValue} + config={config} + setConfig={setConfig} + /> + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodEnum)) { + return ( + } + targetPath={currentPath} + value={currentValue} + config={config} + setConfig={setConfig} + /> + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodRecord)) { + return ( + } + targetPath={currentPath} + value={currentValue} + config={config} + setConfig={setConfig} + /> + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodNumber)) { + 
return ( + + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodLiteral)) { + return ( + } + targetPath={currentPath} + value={currentValue} + config={config} + setConfig={setConfig} + /> + ); + } + + if (isZodType(fieldSchema, s => s instanceof z.ZodBoolean)) { + return ( + + ); + } + + let fieldType = 'text'; + if (isZodType(fieldSchema, s => s instanceof z.ZodNumber)) { + fieldType = 'number'; + } else if (isZodType(fieldSchema, s => s instanceof z.ZodBoolean)) { + fieldType = 'checkbox'; + } + + return ( + + updateConfig(currentPath, e.target.value, config, setConfig) + } + margin="dense" + /> + ); +}; + +const ArrayField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodArray; + targetPath: string[]; + value: any[]; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const arrayValue = useMemo( + () => (Array.isArray(value) ? value : []), + [value] + ); + const minItems = unwrappedSchema._def.minLength?.value ?? 0; + const elementSchema = unwrappedSchema.element; + + // Ensure the minimum number of items is always present + React.useEffect(() => { + if (arrayValue.length < minItems) { + const itemsToAdd = minItems - arrayValue.length; + const newItems = Array(itemsToAdd) + .fill(null) + .map(() => (elementSchema instanceof z.ZodObject ? 
{} : null)); + updateConfig(targetPath, [...arrayValue, ...newItems], config, setConfig); + } + }, [arrayValue, minItems, elementSchema, targetPath, config, setConfig]); + + return ( + + {keyName} + {arrayValue.map((item, index) => ( + + + { + const newArray = [...arrayValue]; + newArray[index] = newItemConfig[`${index}`]; + updateConfig(targetPath, newArray, config, setConfig); + }} + path={[]} + /> + + + + removeArrayItem(targetPath, index, config, setConfig) + } + disabled={arrayValue.length <= minItems}> + + + + + ))} + + + ); +}; + +const EnumField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodEnum; + targetPath: string[]; + value: any; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const options = unwrappedSchema.options; + + // Determine the default value + const defaultValue = React.useMemo(() => { + if (fieldSchema instanceof z.ZodDefault) { + return fieldSchema._def.defaultValue(); + } + if (options.length > 0) { + return options[0]; + } + return undefined; + }, [fieldSchema, options]); + // Use the default value if the current value is null or undefined + const selectedValue = value ?? defaultValue; + + useEffect(() => { + if (value === null || value === undefined) { + updateConfig(targetPath, selectedValue, config, setConfig); + } + }, [value, selectedValue, targetPath, config, setConfig]); + + return ( + + {keyName !== '' ? ( + {keyName} + ) : ( +
+ )} + + + ); +}; + +const RecordField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodRecord; + targetPath: string[]; + value: Record; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const [internalPairs, setInternalPairs] = useState< + Array<{key: string; value: any}> + >([]); + + const valueSchema = unwrappedSchema._def.valueType; + const unwrappedValueSchema = unwrapSchema(valueSchema); + + // Initialize or update internalPairs when value changes + useEffect(() => { + if (value && typeof value === 'object') { + setInternalPairs( + Object.entries(value).map(([key, val]) => ({key, value: val})) + ); + } else { + setInternalPairs([]); + } + }, [value]); + + const updateInternalPair = (index: number, newKey: string, newValue: any) => { + const newPairs = [...internalPairs]; + newPairs[index] = {key: newKey, value: newValue}; + setInternalPairs(newPairs); + + // Update the actual config + const newRecord = newPairs.reduce((acc, {key, value: innerValue}) => { + acc[key] = innerValue; + return acc; + }, {} as Record); + updateConfig(targetPath, newRecord, config, setConfig); + }; + + const addNewPair = () => { + const newKey = `key${internalPairs.length + 1}`; + let defaultValue: any = ''; + if (valueSchema instanceof z.ZodDefault) { + defaultValue = valueSchema._def.defaultValue(); + } else if (valueSchema instanceof z.ZodEnum) { + defaultValue = valueSchema.options[0]; + } else if (valueSchema instanceof z.ZodBoolean) { + defaultValue = false; + } else if (valueSchema instanceof z.ZodNumber) { + defaultValue = 0; + } + setInternalPairs([...internalPairs, {key: newKey, value: defaultValue}]); + updateConfig( + targetPath, + {...value, [newKey]: defaultValue}, + config, + setConfig + ); + }; + + const removePair = (index: number) => { + const newPairs = internalPairs.filter((_, i) => i !== index); + 
setInternalPairs(newPairs); + + const newRecord = newPairs.reduce((acc, {key, value: innerValue}) => { + acc[key] = innerValue; + return acc; + }, {} as Record); + updateConfig(targetPath, newRecord, config, setConfig); + }; + + return ( + + {keyName} + {internalPairs.map(({key, value: innerValue}, index) => ( + + { + if ( + internalPairs.some( + (pair, i) => i !== index && pair.key === e.target.value + ) + ) { + // Prevent duplicate keys + return; + } + updateInternalPair(index, e.target.value, innerValue); + }} + margin="dense" + /> + {isZodType(valueSchema, s => s instanceof z.ZodEnum) ? ( + } + targetPath={[...targetPath, key]} + value={innerValue} + config={config} + setConfig={newConfig => { + const newValue = getNestedValue(newConfig, [ + ...targetPath, + key, + ]); + updateInternalPair(index, key, newValue); + }} + /> + ) : ( + updateInternalPair(index, key, e.target.value)} + margin="dense" + /> + )} + removePair(index)}> + + + + ))} + + + ); +}; + +const getNestedValue = (obj: any, targetPath: string[]): any => { + return targetPath.reduce( + (acc, key) => (acc && acc[key] !== undefined ? 
acc[key] : undefined), + obj + ); +}; + +const updateConfig = ( + targetPath: string[], + value: any, + config: Record, + setConfig: (config: Record) => void +) => { + const newConfig = {...config}; + let current = newConfig; + for (let i = 0; i < targetPath.length - 1; i++) { + if (!(targetPath[i] in current)) { + current[targetPath[i]] = {}; + } + current = current[targetPath[i]]; + } + + // Convert OrderedRecord to plain object if necessary + if ( + value && + typeof value === 'object' && + 'keys' in value && + 'values' in value + ) { + const plainObject: Record = {}; + value.keys.forEach((key: string) => { + plainObject[key] = value.values[key]; + }); + current[targetPath[targetPath.length - 1]] = plainObject; + } else { + current[targetPath[targetPath.length - 1]] = value; + } + + setConfig(newConfig); +}; + +const addArrayItem = ( + targetPath: string[], + elementSchema: z.ZodTypeAny, + config: Record, + setConfig: (config: Record) => void +) => { + const currentArray = getNestedValue(config, targetPath) || []; + const newItem = elementSchema instanceof z.ZodObject ? {} : null; + updateConfig(targetPath, [...currentArray, newItem], config, setConfig); +}; + +const removeArrayItem = ( + targetPath: string[], + index: number, + config: Record, + setConfig: (config: Record) => void +) => { + const currentArray = getNestedValue(config, targetPath) || []; + const unwrappedSchema = unwrapSchema( + getNestedValue(config, targetPath.slice(0, -1)) + ); + const minItems = + unwrappedSchema instanceof z.ZodArray + ? unwrappedSchema._def.minLength?.value ?? 
0 + : 0; + + if (currentArray.length > minItems) { + updateConfig( + targetPath, + currentArray.filter((_: any, i: number) => i !== index), + config, + setConfig + ); + } +}; + +const NumberField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodNumber; + targetPath: string[]; + value: number | undefined; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const min = + (unwrappedSchema._def.checks.find(check => check.kind === 'min') as any) + ?.value ?? undefined; + const max = + (unwrappedSchema._def.checks.find(check => check.kind === 'max') as any) + ?.value ?? undefined; + // const defaultValue = + // fieldSchema instanceof z.ZodDefault + // ? fieldSchema._def.defaultValue() + // : undefined; + + // useEffect(() => { + // if (value === undefined && defaultValue !== undefined) { + // updateConfig(targetPath, defaultValue, config, setConfig); + // } + // }, [value, defaultValue, targetPath, config, setConfig]); + + return ( + { + const newValue = + e.target.value === '' ? 
undefined : Number(e.target.value); + if (newValue !== undefined && (newValue < min || newValue > max)) { + return; + } + updateConfig(targetPath, newValue, config, setConfig); + }} + inputProps={{min, max}} + margin="dense" + /> + ); +}; + +const LiteralField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodLiteral; + targetPath: string[]; + value: any; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const literalValue = unwrappedSchema.value; + + useEffect(() => { + if (value !== literalValue) { + updateConfig(targetPath, literalValue, config, setConfig); + } + }, [value, literalValue, targetPath, config, setConfig]); + + return ( + + ); +}; + +const BooleanField: React.FC<{ + keyName: string; + fieldSchema: z.ZodTypeAny; + unwrappedSchema: z.ZodBoolean; + targetPath: string[]; + value: boolean | undefined; + config: Record; + setConfig: (config: Record) => void; +}> = ({ + keyName, + fieldSchema, + unwrappedSchema, + targetPath, + value, + config, + setConfig, +}) => { + const defaultValue = + fieldSchema instanceof z.ZodDefault + ? 
fieldSchema._def.defaultValue() + : false; + + useEffect(() => { + if (value === undefined) { + updateConfig(targetPath, defaultValue, config, setConfig); + } + }, [value, defaultValue, targetPath, config, setConfig]); + + return ( + + updateConfig(targetPath, e.target.checked, config, setConfig) + } + /> + } + label={keyName} + /> + ); +}; + +export const DynamicConfigForm: React.FC = ({ + configSchema, + config, + setConfig, + path = [], + onValidChange, +}) => { + useEffect(() => { + const validationResult = configSchema.safeParse(config); + if (onValidChange) { + onValidChange(validationResult.success); + } + }, [config, configSchema, onValidChange]); + + const renderContent = () => { + if (configSchema instanceof z.ZodRecord) { + return ( + + ); + } else if (configSchema instanceof z.ZodObject) { + return Object.entries(configSchema.shape).map(([key, fieldSchema]) => ( + + )); + } else { + console.error('Unsupported schema type', configSchema); + return Unsupported schema type; + } + }; + + return <>{renderContent()}; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/ReusableDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/ReusableDrawer.tsx new file mode 100644 index 00000000000..e90e04e51e1 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/ReusableDrawer.tsx @@ -0,0 +1,64 @@ +import {Box, Button, Drawer, Typography} from '@material-ui/core'; +import React, {FC, ReactNode} from 'react'; + +interface ReusableDrawerProps { + open: boolean; + title: string; + onClose: () => void; + onSave: () => void; + saveDisabled?: boolean; + children: ReactNode; +} + +export const ReusableDrawer: FC = ({ + open, + title, + onClose, + onSave, + saveDisabled, + children, +}) => { + return ( + { + // do nothing + return; + }} + ModalProps={{ + keepMounted: true, // Better open performance on mobile + }}> + + + {title} + + + {children} + + + + + + + + ); +}; diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts new file mode 100644 index 00000000000..c79c820b2a4 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts @@ -0,0 +1,51 @@ +import {z} from 'zod'; + +const JSONTypeNames = z.enum(['boolean', 'number', 'string']); +const SimpleJsonResponseFormat = z.object({type: JSONTypeNames}); +const ObjectJsonResponseFormat = z.object({ + type: z.literal('object'), + properties: z.record(SimpleJsonResponseFormat), + additionalProperties: z.literal(false), +}); + +export const ConfiguredLlmJudgeActionSchema = z.object({ + action_type: z.literal('llm_judge'), + model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), + prompt: z.string(), + response_format: z.discriminatedUnion('type', [ + SimpleJsonResponseFormat, + ObjectJsonResponseFormat, + ]), +}); +export type ConfiguredLlmJudgeActionType = z.infer< + typeof ConfiguredLlmJudgeActionSchema +>; + +export const ConfiguredWordCountActionSchema = z.object({ + action_type: z.literal('wordcount'), +}); +export type ConfiguredWordCountActionType = z.infer< + typeof ConfiguredWordCountActionSchema +>; + +export const ActionConfigSchema = z.discriminatedUnion('action_type', [ + ConfiguredLlmJudgeActionSchema, + ConfiguredWordCountActionSchema, +]); +export type ActionConfigType = z.infer; + +export const ConfiguredActionSchema = z.object({ + name: z.string(), + config: ActionConfigSchema, +}); +export type ConfiguredActionType = z.infer; + +export const ActionDispatchFilterSchema = z.object({ + op_name: z.string(), + sample_rate: z.number().min(0).max(1).default(1), + configured_action_ref: z.string(), + disabled: z.boolean().optional(), +}); +export type ActionDispatchFilterType = z.infer< + typeof ActionDispatchFilterSchema +>; diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts new file mode 100644 index 00000000000..8698054e3d4 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts @@ -0,0 +1,9 @@ +import { + ActionDispatchFilterSchema, + ConfiguredActionSchema, +} from './actionCollection'; + +export const collectionRegistry = { + ConfiguredAction: ConfiguredActionSchema, + ActionDispatchFilter: ActionDispatchFilterSchema, +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx new file mode 100644 index 00000000000..1f3707aff05 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx @@ -0,0 +1,118 @@ +import {useDeepMemo} from '@wandb/weave/hookUtils'; +import {useEffect, useState} from 'react'; +import {z} from 'zod'; + +import {TraceServerClient} from '../pages/wfReactInterface/traceServerClient'; +import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; +import { + TraceObjCreateReq, + TraceObjQueryReq, + TraceObjSchema, +} from '../pages/wfReactInterface/traceServerClientTypes'; +import {collectionRegistry} from './collectionRegistry'; + +export const useCollectionObjects = < + C extends keyof typeof collectionRegistry, + T extends z.infer<(typeof collectionRegistry)[C]> +>( + collectionName: C, + req: TraceObjQueryReq +) => { + const [objects, setObjects] = useState>>([]); + const getTsClient = useGetTraceServerClientContext(); + const client = getTsClient(); + const deepReq = useDeepMemo(req); + + useEffect(() => { + let isMounted = true; + getCollectionObjects(client, collectionName, deepReq).then( + collectionObjects => { + if (isMounted) { + 
setObjects(collectionObjects as Array>); + } + } + ); + return () => { + isMounted = false; + }; + }, [client, collectionName, deepReq]); + + return objects; +}; + +const getCollectionObjects = async < + C extends keyof typeof collectionRegistry, + T extends z.infer<(typeof collectionRegistry)[C]> +>( + client: TraceServerClient, + collectionName: C, + req: TraceObjQueryReq +): Promise>> => { + const knownCollection = collectionRegistry[collectionName]; + if (!knownCollection) { + console.warn(`Unknown collection: ${collectionName}`); + return []; + } + + const reqWithCollection: TraceObjQueryReq = { + ...req, + filter: {...req.filter, base_object_classes: [collectionName]}, + }; + + const objectPromise = client.objsQuery(reqWithCollection); + + const objects = await objectPromise; + + return objects.objs + .map(obj => ({obj, parsed: knownCollection.safeParse(obj.val)})) + .filter(({parsed}) => parsed.success) + .map(({obj, parsed}) => ({...obj, val: parsed.data!})) as Array< + TraceObjSchema + >; +}; + +export const useCreateCollectionObject = < + C extends keyof typeof collectionRegistry, + T extends z.infer<(typeof collectionRegistry)[C]> +>( + collectionName: C +) => { + const getTsClient = useGetTraceServerClientContext(); + const client = getTsClient(); + return (req: TraceObjCreateReq) => + createCollectionObject(client, collectionName, req); +}; + +const createCollectionObject = async < + C extends keyof typeof collectionRegistry, + T extends z.infer<(typeof collectionRegistry)[C]> +>( + client: TraceServerClient, + collectionName: C, + req: TraceObjCreateReq +) => { + const knownCollection = collectionRegistry[collectionName]; + if (!knownCollection) { + throw new Error(`Unknown collection: ${collectionName}`); + } + + const verifiedObject = knownCollection.safeParse(req.obj.val); + + if (!verifiedObject.success) { + throw new Error( + `Invalid object: ${JSON.stringify(verifiedObject.error.errors)}` + ); + } + + const reqWithCollection: TraceObjCreateReq = 
{ + ...req, + obj: { + ...req.obj, + val: {...req.obj.val, _bases: [collectionName, 'BaseModel']}, + }, + }; + + const createPromse = client.objCreate(reqWithCollection); + + return createPromse; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx index f453e583c54..fddcd4fbda4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx @@ -193,6 +193,10 @@ export const browse2Context = { ) => { throw new Error('Not implemented'); }, + + metricsUIUrl: (entityName: string, projectName: string) => { + throw new Error('Not implemented'); + }, leaderboardsUIUrl: ( entityName: string, projectName: string, @@ -430,6 +434,11 @@ export const browse3ContextGen = ( JSON.stringify(evaluationCallIds) )}${metricsPart}`; }, + + metricsUIUrl: (entityName: string, projectName: string) => { + return `${projectRoot(entityName, projectName)}/metrics`; + }, + leaderboardsUIUrl: ( entityName: string, projectName: string, @@ -524,6 +533,9 @@ type RouteType = { evaluationCallIds: string[], metrics: Record | null ) => string; + + metricsUIUrl: (entityName: string, projectName: string) => string; + leaderboardsUIUrl: ( entityName: string, projectName: string, @@ -643,6 +655,9 @@ const useMakePeekingRouter = (): RouteType => { baseContext.compareEvaluationsUri(...args) ); }, + metricsUIUrl: (...args: Parameters) => { + return setSearchParam(PEEK_PARAM, baseContext.metricsUIUrl(...args)); + }, leaderboardsUIUrl: ( ...args: Parameters ) => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx new file mode 100644 index 00000000000..520bced46de --- /dev/null +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -0,0 +1,277 @@ +import {Button} from '@wandb/weave/components/Button/Button'; +import {Timestamp} from '@wandb/weave/components/Timestamp'; +import {parseRef} from '@wandb/weave/react'; +import {makeRefCall} from '@wandb/weave/util/refs'; +import React, {useCallback, useMemo, useState} from 'react'; +import {z} from 'zod'; + +import {CellValue} from '../../../Browse2/CellValue'; +import {NotApplicable} from '../../../Browse2/NotApplicable'; +import {ConfiguredActionType} from '../../collections/actionCollection'; +import {useCollectionObjects} from '../../collections/getCollectionObjects'; +import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component +import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; +import {useWFHooks} from '../wfReactInterface/context'; +import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext'; +import {Feedback} from '../wfReactInterface/traceServerClientTypes'; +import { + convertISOToDate, + projectIdFromParts, +} from '../wfReactInterface/tsDataModelHooks'; +import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; +import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; + +type CallActionRow = { + actionRef: string; + actionDef: ConfiguredActionType; + runCount: number; + lastResult?: unknown; + lastRanAt?: Date; +}; +// New RunButton component +const RunButton: React.FC<{ + actionRef: string; + callId: string; + entity: string; + project: string; + refetchFeedback: () => void; + getClient: () => any; +}> = ({actionRef, callId, entity, project, refetchFeedback, getClient}) => { + const [isRunning, setIsRunning] = useState(false); + const [error, setError] = useState(null); + + const handleRunClick = async () => { + setIsRunning(true); + setError(null); + try { + await getClient().actionsExecuteBatch({ + project_id: 
projectIdFromParts({entity, project}), + call_ids: [callId], + configured_action_ref: actionRef, + }); + refetchFeedback(); + } catch (err) { + setError('An error occurred while running the action.'); + } finally { + setIsRunning(false); + } + }; + + if (error) { + return ( + + ); + } + + return ( +
+ +
+ ); +}; + +export const CallActionsViewer: React.FC<{ + call: CallSchema; +}> = props => { + const {useFeedback} = useWFHooks(); + const weaveRef = makeRefCall( + props.call.entity, + props.call.project, + props.call.callId + ); + const feedbackQuery = useFeedback({ + entity: props.call.entity, + project: props.call.project, + weaveRef, + }); + + const configuredActions = useCollectionObjects('ConfiguredAction', { + project_id: projectIdFromParts({ + entity: props.call.entity, + project: props.call.project, + }), + filter: {latest_only: true}, + }).sort((a, b) => a.val.name.localeCompare(b.val.name)); + const verifiedActionFeedbacks: Array<{ + data: MachineScoreFeedbackPayloadType; + feedbackRaw: Feedback; + }> = useMemo(() => { + return (feedbackQuery.result ?? []) + .map(feedback => { + const res = MachineScoreFeedbackPayloadSchema.safeParse( + feedback.payload + ); + return {res, feedbackRaw: feedback}; + }) + .filter(result => result.res.success) + .map(result => ({ + data: result.res.data, + feedbackRaw: result.feedbackRaw, + })) as Array<{ + data: MachineScoreFeedbackPayloadType; + feedbackRaw: Feedback; + }>; + }, [feedbackQuery.result]); + + const getFeedbackForAction = useCallback( + (actionRef: string) => { + return verifiedActionFeedbacks.filter( + feedback => feedback.data.runnable_ref === actionRef + ); + }, + [verifiedActionFeedbacks] + ); + + const getClient = useGetTraceServerClientContext(); + + const allCallActions: CallActionRow[] = useMemo(() => { + return ( + configuredActions?.map(configuredAction => { + const configuredActionRefUri = objectVersionKeyToRefUri({ + scheme: WEAVE_REF_SCHEME, + weaveKind: 'object', + entity: props.call.entity, + project: props.call.project, + objectId: configuredAction.object_id, + versionHash: configuredAction.digest, + path: '', + }); + const feedbacks = getFeedbackForAction(configuredActionRefUri); + const selectedFeedback = + feedbacks.length > 0 ? 
feedbacks[0] : undefined; + return { + actionRef: configuredActionRefUri, + actionDef: configuredAction.val, + runCount: feedbacks.length, + lastRanAt: selectedFeedback + ? convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') + : undefined, + lastResult: selectedFeedback + ? getValueFromMachineScoreFeedbackPayload(selectedFeedback.data) + : undefined, + }; + }) ?? [] + ); + }, [ + configuredActions, + getFeedbackForAction, + props.call.entity, + props.call.project, + ]); + + const columns = [ + {field: 'action', headerName: 'Action', flex: 1}, + {field: 'runCount', headerName: 'Run Count', flex: 1}, + { + field: 'lastResult', + headerName: 'Last Result', + flex: 1, + renderCell: (params: any) => { + const value = params.row.lastResult; + if (value == null) { + return ; + } + return ; + }, + }, + { + field: 'lastRanAt', + headerName: 'Last Ran At', + flex: 1, + renderCell: (params: any) => { + const value = params.row.lastRanAt + ? params.row.lastRanAt.getTime() / 1000 + : undefined; + if (value == null) { + return ; + } + return ; + }, + }, + { + field: 'run', + headerName: 'Run', + flex: 1, + renderCell: (params: any) => ( + + ), + }, + ]; + + const rows = allCallActions.map((action, index) => ({ + id: index, + action: action.actionDef.name, + runCount: action.runCount, + lastResult: action.lastResult, + lastRanAt: action.lastRanAt, + actionRef: action.actionRef, + })); + return ( + <> + + + ); +}; + +const MachineScoreFeedbackPayloadSchema = z.object({ + // _type: z.literal("ActionFeedback"), + runnable_ref: z.string(), + call_ref: z.string().optional(), + trigger_ref: z.string().optional(), + value: z.record(z.string(), z.record(z.string(), z.boolean())), +}); + +type MachineScoreFeedbackPayloadType = z.infer< + typeof MachineScoreFeedbackPayloadSchema +>; + +const getValueFromMachineScoreFeedbackPayload = ( + payload: MachineScoreFeedbackPayloadType +) => { + const ref = parseRef(payload.runnable_ref); + const name = ref.artifactName; + const digest 
= ref.artifactVersion; + return payload.value[name][digest]; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index bd56e5c830d..5d58634d48c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -16,6 +16,7 @@ import { } from '../../context'; import {FeedbackGrid} from '../../feedback/FeedbackGrid'; import {NotFoundPanel} from '../../NotFoundPanel'; +import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from '../../windowFlags'; import {isCallChat} from '../ChatView/hooks'; import {isEvaluateOp} from '../common/heuristics'; import {CenteredAnimatedLoader} from '../common/Loader'; @@ -28,11 +29,13 @@ import {TabUseCall} from '../TabUseCall'; import {useURLSearchParamsDict} from '../util'; import {useWFHooks} from '../wfReactInterface/context'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; +import {CallActionsViewer} from './CallActionsViewer'; import {CallChat} from './CallChat'; import {CallDetails} from './CallDetails'; import {CallOverview} from './CallOverview'; import {CallSummary} from './CallSummary'; import {CallTraceView, useCallFlattenedTraceTree} from './CallTraceView'; + export const CallPage: FC<{ entity: string; project: string; @@ -59,6 +62,7 @@ const useCallTabs = (call: CallSchema) => { const codeURI = call.opVersionRef; const {entity, project, callId} = call; const weaveRef = makeRefCall(entity, project, callId); + const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); return [ // Disabling Evaluation tab until it's better for single evaluation ...(false && isEvaluateOp(call.spanName) @@ -127,6 +131,18 @@ const useCallTabs = (call: CallSchema) => { ), }, + ...(enableOnlineEvalUI + ? 
[ + { + label: 'Scores', + content: ( + + + + ), + }, + ] + : []), { label: 'Use', content: ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index 1a918d22541..a21e47e3917 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -73,7 +73,7 @@ type RefValues = Record; // ref URI to value type TruncatedStore = {[key: string]: {values: any; index: number}}; -const RESOLVED_REF_KEY = '_ref'; +const RESOVLED_REF_KEY = '_ref'; export const ARRAY_TRUNCATION_LENGTH = 50; const TRUNCATION_KEY = '__weave_array_truncated__'; @@ -163,13 +163,13 @@ export const ObjectViewer = ({ if (typeof val === 'object' && val !== null) { val = { ...v, - [RESOLVED_REF_KEY]: r, + [RESOVLED_REF_KEY]: r, }; } else { // This makes it so that runs pointing to primitives can still be expanded in the table. 
val = { '': v, - [RESOLVED_REF_KEY]: r, + [RESOVLED_REF_KEY]: r, }; } } @@ -182,7 +182,7 @@ export const ObjectViewer = ({ isWeaveRef(context.value) && refValues[context.value] != null && // Don't expand _ref keys - context.path.tail() !== RESOLVED_REF_KEY + context.path.tail() !== RESOVLED_REF_KEY ) { dirty = true; return refValues[context.value]; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 9481cf114ae..4b95e9b2ca0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -53,6 +53,8 @@ const OBJECT_ICONS: Record = { Dataset: 'table', Evaluation: 'baseline-alt', Leaderboard: 'benchmark-square', + Scorer: 'type-number-alt', + ConfiguredAction: 'rocket-launch', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx new file mode 100644 index 00000000000..49ad30e69e4 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx @@ -0,0 +1,289 @@ +import {Box} from '@material-ui/core'; +import {Button} from '@wandb/weave/components/Button/Button'; +import {parseRef} from '@wandb/weave/react'; +import React, {FC, useMemo, useState} from 'react'; +import {z} from 'zod'; + +import {SmallRef} from '../../../Browse2/SmallRef'; +import { + ActionDispatchFilterSchema, + ActionDispatchFilterType, +} from '../../collections/actionCollection'; +import {collectionRegistry} from '../../collections/collectionRegistry'; +import { + useCollectionObjects, + useCreateCollectionObject, +} from 
'../../collections/getCollectionObjects'; +import {DynamicConfigForm} from '../../DynamicConfigForm'; +import {ReusableDrawer} from '../../ReusableDrawer'; +import {StyledDataGrid} from '../../StyledDataGrid'; +import {TraceObjSchema} from '../wfReactInterface/traceServerClientTypes'; +import { + convertISOToDate, + projectIdFromParts, +} from '../wfReactInterface/tsDataModelHooks'; +import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; +import {OpVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface'; + +type OnlineScorerType = TraceObjSchema; + +const useOnlineScorersForOpVersion = ( + opVersion: OpVersionSchema +): {scorers: OnlineScorerType[]; refresh: () => void} => { + const [poorMansRefreshCount, setPoorMansRefreshCount] = useState(0); + const req = useMemo(() => { + return { + project_id: projectIdFromParts({ + entity: opVersion.entity, + project: opVersion.project, + }), + filter: { + latest_only: true, + }, + poorMansRefreshCount, + }; + }, [opVersion.entity, opVersion.project, poorMansRefreshCount]); + const scorers = useCollectionObjects('ActionDispatchFilter', req).sort( + (a, b) => { + return ( + convertISOToDate(a.created_at).getTime() - + convertISOToDate(b.created_at).getTime() + ); + } + ); + const refresh = () => { + setPoorMansRefreshCount(poorMansRefreshCount + 1); + }; + return {scorers, refresh}; +}; + +export const OpOnlineScorersTab: React.FC<{ + opVersion: OpVersionSchema; +}> = ({opVersion}) => { + const req = useMemo(() => { + return { + project_id: projectIdFromParts({ + entity: opVersion.entity, + project: opVersion.project, + }), + filter: { + latest_only: true, + }, + }; + }, [opVersion.entity, opVersion.project]); + const availableActions = useCollectionObjects('ConfiguredAction', req); + const [isModalOpen, setIsModalOpen] = useState(false); + const {scorers: onlineScorers, refresh} = + useOnlineScorersForOpVersion(opVersion); + + const handleOpenModal = () => { + setIsModalOpen(true); + }; + + const 
handleCloseModal = (didSave: boolean) => { + refresh(); + setIsModalOpen(false); + }; + + const columns = [ + { + field: 'name', + headerName: 'Name', + flex: 1, + renderCell: (params: any) => { + return ; + }, + }, + {field: 'sampleRate', headerName: 'Sample Rate', flex: 1}, + { + field: 'configuredActionRef', + headerName: 'Configured Action', + flex: 1, + renderCell: (params: any) => { + return ; + }, + }, + {field: 'disabled', headerName: 'Disabled', flex: 1}, + ]; + + const rows = onlineScorers + .filter(scorer => scorer.val.op_name === opVersion.opId) + .map((scorer, index) => { + const scorerRef = objectVersionKeyToRefUri({ + scheme: 'weave', + weaveKind: 'object', + entity: opVersion.entity, + project: opVersion.project, + objectId: scorer.object_id, + versionHash: scorer.digest, + path: '', + }); + return { + id: scorerRef, + name: scorerRef, + createdAt: convertISOToDate(scorer.created_at), + disabled: scorer.val.disabled, + sampleRate: scorer.val.sample_rate, + configuredActionRef: scorer.val.configured_action_ref, + // Map other fields as needed + }; + }); + + const actionRefs = useMemo(() => { + return availableActions.map(action => { + return objectVersionKeyToRefUri({ + scheme: 'weave', + weaveKind: 'object', + entity: opVersion.entity, + project: opVersion.project, + objectId: action.object_id, + versionHash: action.digest, + path: '', + }); + }); + }, [availableActions, opVersion.entity, opVersion.project]); + + const inputSchema = useMemo(() => { + const base = ActionDispatchFilterSchema.merge( + z.object({ + op_name: z.literal(opVersion.opId), + }) + ); + if (actionRefs.length === 0) { + return base; + } + + return base.merge( + z.object({ + configured_action_ref: z.enum( + actionRefs as unknown as [string, ...string[]] + ), + }) + ); + }, [actionRefs, opVersion.opId]); + + return ( + + + + + + + + ); +}; + +interface NewOnlineOpScorerModalProps { + entity: string; + project: string; + collectionDef: { + name: keyof typeof collectionRegistry; + 
schema: z.Schema; + }; + isOpen: boolean; + onClose: (didSave: boolean) => void; +} + +export const NewOnlineOpScorerModal: FC = ({ + entity, + project, + collectionDef, + isOpen, + onClose, +}) => { + const [config, setConfig] = useState>({}); + + const createCollectionObject = useCreateCollectionObject(collectionDef.name); + + const handleSaveModal = (newAction: Record) => { + const parsedAction = collectionDef.schema.safeParse(newAction); + if (!parsedAction.success) { + console.error( + `Invalid action: ${JSON.stringify(parsedAction.error.errors)}` + ); + return; + } + const opName = parsedAction.data.op_name; + const actionRef = parsedAction.data.configured_action_ref; + const actionName = parseRef(actionRef).artifactName; + let objectId = `${opName}-${actionName}`; + // Remove non alphanumeric characters + objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); + createCollectionObject({ + obj: { + project_id: projectIdFromParts({entity, project}), + object_id: objectId, + val: parsedAction.data, + }, + }) + .catch(err => { + console.error(err); + }) + .finally(() => { + setConfig({}); + onClose(true); + }); + }; + + const [isValid, setIsValid] = useState(false); + + return ( + onClose(false)} + onSave={() => handleSaveModal(config)} + saveDisabled={!isValid}> + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx new file mode 100644 index 00000000000..1be7e32dfc0 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx @@ -0,0 +1,163 @@ +import React, {useMemo} from 'react'; + +import {LoadingDots} from '../../../../../LoadingDots'; +import {Tailwind} from '../../../../../Tailwind'; +import {NotFoundPanel} from '../../NotFoundPanel'; +import {OpCodeViewer} from '../../OpCodeViewer'; +import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from 
'../../windowFlags'; +import { + CallsLink, + opNiceName, + OpVersionsLink, + opVersionText, +} from '../common/Links'; +import {CenteredAnimatedLoader} from '../common/Loader'; +import { + ScrollableTabContent, + SimpleKeyValueTable, + SimplePageLayoutWithHeader, +} from '../common/SimplePageLayout'; +import {TabUseOp} from '../TabUseOp'; +import {useWFHooks} from '../wfReactInterface/context'; +import {opVersionKeyToRefUri} from '../wfReactInterface/utilities'; +import {OpVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface'; +import {OpOnlineScorersTab} from './OpOnlineScorersTab'; + +export const OpVersionPage: React.FC<{ + entity: string; + project: string; + opName: string; + version: string; +}> = props => { + const {useOpVersion} = useWFHooks(); + + const opVersion = useOpVersion({ + entity: props.entity, + project: props.project, + opId: props.opName, + versionHash: props.version, + }); + if (opVersion.loading) { + return ; + } else if (opVersion.result == null) { + return ; + } + return ; +}; + +const OpVersionPageInner: React.FC<{ + opVersion: OpVersionSchema; +}> = ({opVersion}) => { + const {useOpVersions, useCallsStats} = useWFHooks(); + const uri = opVersionKeyToRefUri(opVersion); + const {entity, project, opId, versionIndex} = opVersion; + + const opVersions = useOpVersions( + entity, + project, + { + opIds: [opId], + }, + undefined, // limit + true // metadataOnly + ); + const opVersionCount = (opVersions.result ?? []).length; + const callsStats = useCallsStats(entity, project, { + opVersionRefs: [uri], + }); + const opVersionCallCount = callsStats?.result?.count ?? 0; + const useOpSupported = useMemo(() => { + // TODO: We really want to return `True` only when + // the op is not a bound op. However, we don't have + // that data available yet. + return true; + }, []); + const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); + + return ( + + {opId}{' '} + {opVersions.loading ? 
( + + ) : ( + <> + [ + + ] + + )} + + ), + Version: <>{versionIndex}, + Calls: + !callsStats.loading || opVersionCallCount > 0 ? ( + + ) : ( + <> + ), + }} + /> + } + tabs={[ + { + label: 'Code', + content: ( + + ), + }, + ...(enableOnlineEvalUI + ? [ + { + label: 'Online Scorers', + content: , + }, + ] + : []), + ...(useOpSupported + ? [ + { + label: 'Use', + content: ( + + + + + + ), + }, + ] + : []), + ]} + /> + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx new file mode 100644 index 00000000000..f332e36d77c --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -0,0 +1,155 @@ +import { + FormControl, + InputLabel, + MenuItem, + Select, + TextField, +} from '@material-ui/core'; +import _ from 'lodash'; +import React, {FC, useEffect, useState} from 'react'; +import {z} from 'zod'; + +import { + ConfiguredActionSchema, + ConfiguredActionType, + ConfiguredLlmJudgeActionSchema, + ConfiguredWordCountActionSchema, +} from '../../collections/actionCollection'; +import {DynamicConfigForm} from '../../DynamicConfigForm'; +import {ReusableDrawer} from '../../ReusableDrawer'; +import { + actionTemplates, + ConfiguredLlmJudgeActionFriendlySchema, +} from './actionTemplates'; + +const knownBuiltinActions = [ + { + name: 'LLM Judge', + actionSchema: ConfiguredLlmJudgeActionSchema, + friendly: { + schema: ConfiguredLlmJudgeActionFriendlySchema, + convert: ( + data: z.infer + ): z.infer => { + let responseFormat: z.infer< + typeof ConfiguredLlmJudgeActionSchema + >['response_format']; + if (data.response_format.type === 'simple') { + responseFormat = {type: data.response_format.schema}; + } else { + responseFormat = { + type: 'object', + properties: _.mapValues(data.response_format.schema, value => ({ + type: value as 
'boolean' | 'number' | 'string', + })), + additionalProperties: false, + }; + } + return { + action_type: 'llm_judge', + model: data.model, + prompt: data.prompt, + response_format: responseFormat, + }; + }, + }, + }, + { + name: 'Word Count', + actionSchema: ConfiguredWordCountActionSchema, + friendly: { + schema: z.object({}), + convert: ( + data: z.infer + ): z.infer => { + return { + action_type: 'wordcount', + }; + }, + }, + }, +]; + +interface NewBuiltInActionScorerModalProps { + open: boolean; + onClose: () => void; + onSave: (newAction: ConfiguredActionType) => void; + initialTemplate: string; +} + +export const NewBuiltInActionScorerModal: FC< + NewBuiltInActionScorerModalProps +> = ({open, onClose, onSave, initialTemplate}) => { + const [name, setName] = useState(''); + const [selectedActionIndex, setSelectedActionIndex] = useState(0); + const [config, setConfig] = useState>({}); + + useEffect(() => { + if (initialTemplate) { + const template = actionTemplates.find(t => t.name === initialTemplate); + if (template) { + setConfig(template.type); + setName(template.name); + } + } else { + setConfig({}); + setName(''); + } + }, [initialTemplate]); + + const handleSave = () => { + const newAction = ConfiguredActionSchema.parse({ + name, + config: knownBuiltinActions[selectedActionIndex].friendly.convert( + config as any + ), + }); + onSave(newAction); + setConfig({}); + setSelectedActionIndex(0); + setName(''); + }; + + const [isValid, setIsValid] = useState(false); + + return ( + + setName(e.target.value)} + margin="normal" + /> + + Action Type + + + {selectedActionIndex !== -1 && ( + + )} + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx new file mode 100644 index 00000000000..585d0c7e47f --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ 
-0,0 +1,219 @@ +import {Box} from '@material-ui/core'; +import {Alert} from '@mui/material'; +import Menu from '@mui/material/Menu'; +import MenuItem from '@mui/material/MenuItem'; +import {Button} from '@wandb/weave/components/Button/Button'; +import React, {FC, useState} from 'react'; + +import {ConfiguredActionType} from '../../collections/actionCollection'; +import {useCreateCollectionObject} from '../../collections/getCollectionObjects'; +import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; +import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; +import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; +import {actionTemplates} from './actionTemplates'; +import {NewBuiltInActionScorerModal} from './NewBuiltInActionScorerModal'; + +export const ScorersPage: React.FC<{ + entity: string; + project: string; +}> = ({entity, project}) => { + return ( + , + }, + { + label: 'Human Review', + content: , + }, + { + label: 'Code Scorers', + content: , + }, + ]} + headerContent={undefined} + /> + ); +}; + +const CodeScorersTab: React.FC<{ + entity: string; + project: string; +}> = ({entity, project}) => { + return ( + + ); +}; + +const HumanScorersTab: React.FC<{ + entity: string; + project: string; +}> = ({entity, project}) => { + return ( + + Human Review coming soon + + ); + // return ( + // + // ); +}; + +const OnlineScorersTab: React.FC<{ + entity: string; + project: string; +}> = ({entity, project}) => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [selectedTemplate, setSelectedTemplate] = useState(''); + const createCollectionObject = useCreateCollectionObject('ConfiguredAction'); + const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); + + const [anchorEl, setAnchorEl] = React.useState(null); + const open = Boolean(anchorEl); + + const handleCreateBlank = () => { + setSelectedTemplate(''); + setIsModalOpen(true); + }; + + const handleDropdownClick = (event: React.MouseEvent) => { + 
setAnchorEl(event.currentTarget); + }; + + const handleClose = () => { + setAnchorEl(null); + }; + + const handleTemplateSelect = (templateName: string) => { + setSelectedTemplate(templateName); + setIsModalOpen(true); + handleClose(); + }; + + const handleCloseModal = () => { + setIsModalOpen(false); + setSelectedTemplate(''); + }; + + const handleSaveModal = (newAction: ConfiguredActionType) => { + let objectId = newAction.name; + // Remove non alphanumeric characters + objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); + createCollectionObject({ + obj: { + project_id: projectIdFromParts({entity, project}), + object_id: objectId, + val: newAction, + }, + }) + .then(() => { + setLastUpdatedTimestamp(Date.now()); + }) + .catch(err => { + console.error(err); + }) + .finally(() => { + handleCloseModal(); + }); + }; + + return ( + + + + + + +); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx new file mode 100644 index 00000000000..51f7ca32bcd --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx @@ -0,0 +1,58 @@ +import {z} from 'zod'; + +const SimpleResponseFormatSchema = z + .enum(['boolean', 'number', 'string']) + .default('boolean'); +const StructuredResponseFormatSchema = z.record(SimpleResponseFormatSchema); + +const ResponseFormatSchema = z.discriminatedUnion('type', [ + z.object({ + type: z.literal('simple'), + schema: SimpleResponseFormatSchema, + }), + z.object({ + type: z.literal('structured'), + schema: StructuredResponseFormatSchema, + }), +]); + +export const ConfiguredLlmJudgeActionFriendlySchema = z.object({ + model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), + prompt: z.string(), + response_format: ResponseFormatSchema, +}); +type ConfiguredLlmJudgeActionFriendlyType = z.infer< + typeof 
ConfiguredLlmJudgeActionFriendlySchema +>; + +export const actionTemplates: Array<{ + name: string; + type: ConfiguredLlmJudgeActionFriendlyType; +}> = [ + { + name: 'RelevancyJudge', + type: { + model: 'gpt-4o-mini', + prompt: 'Is the output relevant to the input?', + response_format: { + type: 'simple', + schema: 'boolean', + }, + }, + }, + { + name: 'CorrectnessJudge', + type: { + model: 'gpt-4o-mini', + prompt: + 'Given the input and output, and your knowledge of the world, is the output correct?', + response_format: { + type: 'structured', + schema: { + is_correct: 'boolean', + reason: 'string', + }, + }, + }, + }, +]; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index e4f08956eb9..fb9756fc7cc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -272,21 +272,23 @@ export const SimplePageLayoutWithHeader: FC<{ {props.headerContent} {(!props.hideTabsIfSingle || tabs.length > 1) && ( - - - {tabs.map(tab => ( - - {tab.label} - - ))} - - + + + + {tabs.map(tab => ( + + {tab.label} + + ))} + + + )} {key} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx index c9609197b9e..ec90e40ecb5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx @@ -9,6 +9,8 @@ const colorMap: Record = { Dataset: 'green', Evaluation: 'cactus', Leaderboard: 'gold', + Scorer: 'purple', + ConfiguredAction: 'sienna', }; export const TypeVersionCategoryChip: React.FC<{ diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index 53f57f4ba09..f9a9b5cda8d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -25,4 +25,6 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Dataset', 'Evaluation', 'Leaderboard', + 'Scorer', + 'ConfiguredAction', ] as const; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index d7219af5834..21565df0533 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -306,3 +306,11 @@ export const fileExtensions = { [ContentType.any]: 'jsonl', [ContentType.json]: 'json', }; + +export type ActionsExecuteBatchReq = { + project_id: string; + call_ids: string[]; + configured_action_ref: string; +}; + +export type ActionsExecuteBatchRes = {}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts index 25f11624740..b4bc850edbb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts @@ -16,6 +16,8 @@ import {getCookie} from '@wandb/weave/common/util/cookie'; import fetch from 'isomorphic-unfetch'; import { + ActionsExecuteBatchReq, + 
ActionsExecuteBatchRes, ContentType, FeedbackCreateReq, FeedbackCreateRes, @@ -226,13 +228,6 @@ export class DirectTraceServerClient { return this.makeRequest('/obj/read', req); } - public readBatch(req: TraceRefsReadBatchReq): Promise { - return this.makeRequest( - '/refs/read_batch', - req - ); - } - public objCreate(req: TraceObjCreateReq): Promise { const initialObjectId = req.obj.object_id; const sanitizedObjectId = sanitizeObjectId(initialObjectId); @@ -249,6 +244,13 @@ export class DirectTraceServerClient { ); } + public readBatch(req: TraceRefsReadBatchReq): Promise { + return this.makeRequest( + '/refs/read_batch', + req + ); + } + public tableQuery(req: TraceTableQueryReq): Promise { return this.makeRequest( '/table/query', @@ -286,6 +288,15 @@ export class DirectTraceServerClient { ); } + public actionsExecuteBatch( + req: ActionsExecuteBatchReq + ): Promise { + return this.makeRequest( + '/actions/execute_batch', + req + ); + } + public fileContent( req: TraceFileContentReadReq ): Promise { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx new file mode 100644 index 00000000000..6879c66ead0 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx @@ -0,0 +1,103 @@ +/** + * Feature Flag Management Utility + * + * This module provides a simple feature flag management system for toggling beta features. + * It allows developers to specify a list of flags, get the current state of any given flag, + * and enables users to toggle flags via the browser console. + * + * Usage: + * 1. Initialize flags at the app's entry point: + * import { initializeFlags } from './windowFlags'; + * initializeFlags(['BETA_FEATURE_1', 'BETA_FEATURE_2']); + * + * 2. 
Use the hook in React components: + * import { useFeatureFlag } from './windowFlags'; + * const MyComponent = () => { + * const isBetaFeature1Enabled = useFeatureFlag('BETA_FEATURE_1'); + * return isBetaFeature1Enabled ? : null; + * }; + * + * 3. Toggle flags from the browser console: + * window.setFeatureFlag('BETA_FEATURE_1', true); + * + * This system allows for easy management of feature flags and real-time updates + * in React components when flags are toggled. + */ + +import {useEffect, useState} from 'react'; + +// Define the type for the flags +type FeatureFlags = { + [key: string]: boolean; +}; + +// Initialize the flags on the window object +declare global { + interface Window { + featureFlags: FeatureFlags; + setFeatureFlag: (flagName: string, value: boolean) => void; + } +} + +// Initialize the feature flags +const initializeFeatureFlags = (flags: string[]) => { + window.featureFlags = flags.reduce((acc, flag) => { + acc[flag] = false; + return acc; + }, {} as FeatureFlags); + + // Expose a method to set flags + window.setFeatureFlag = (flagName: string, value: boolean) => { + if (flagName in window.featureFlags) { + window.featureFlags[flagName] = value; + // Dispatch a custom event when a flag is changed + window.dispatchEvent( + new CustomEvent('featureFlagChanged', {detail: {flagName, value}}) + ); + } else { + console.warn(`Feature flag "${flagName}" is not defined.`); + } + }; +}; + +// Function to get the current state of a flag +export const getFeatureFlag = (flagName: string): boolean => { + return window.featureFlags?.[flagName] ?? 
false; +}; + +// React hook to use feature flags in components +export const useFeatureFlag = (flagName: string): boolean => { + const [flagValue, setFlagValue] = useState(getFeatureFlag(flagName)); + + useEffect(() => { + const handleFlagChange = (event: CustomEvent) => { + if (event.detail.flagName === flagName) { + setFlagValue(event.detail.value); + } + }; + window.addEventListener( + 'featureFlagChanged', + handleFlagChange as EventListener + ); + + return () => { + window.removeEventListener( + 'featureFlagChanged', + handleFlagChange as EventListener + ); + }; + }, [flagName]); + + return flagValue; +}; + +// Function to initialize the feature flags +export const initializeFlags = (flags: string[]) => { + initializeFeatureFlags(flags); +}; + +export const ENABLE_ONLINE_EVAL_UI = 'ENABLE_ONLINE_EVAL_UI'; +initializeFlags([ENABLE_ONLINE_EVAL_UI]); + +// Set the feature flag to true for testing +window.setFeatureFlag('ENABLE_ONLINE_EVAL_UI', true); diff --git a/weave-js/src/components/Panel2/PanelTable/tableState.ts b/weave-js/src/components/Panel2/PanelTable/tableState.ts index 4cb8d286798..565fbb4b7f3 100644 --- a/weave-js/src/components/Panel2/PanelTable/tableState.ts +++ b/weave-js/src/components/Panel2/PanelTable/tableState.ts @@ -1,6 +1,5 @@ import { allObjPaths, - canSortType, constFunction, ConstNode, constNodeUnsafe, @@ -688,16 +687,6 @@ export async function disableGroupByCol( ) { const colIds = _.isArray(colId) ? 
colId : [colId]; const groupBy = ts.groupBy; - - // (WB-16067) - // We may try to sort on aggregated columns after ungrouping - // To prevent this, disable sorting on all the columns and re-enable - // after the ungroup - const initiallySortedCols = _.clone(ts.sort); - ts.sort.forEach(sortObj => { - ts = disableSortByCol(ts, sortObj.columnId); - }); - ts = produce(ts, draft => { draft.autoColumns = false; for (const cid of colIds) { @@ -712,15 +701,9 @@ export async function disableGroupByCol( } }); ts = await refreshSelectFunctions(ts, inputArrayNode, weave, stack); - - initiallySortedCols.forEach(sortObj => { - if ( - sortObj.columnId !== colId && - canSortType(ts.columnSelectFunctions[sortObj.columnId].type) - ) { - ts = enableSortByCol(ts, sortObj.columnId, sortObj.dir === 'asc'); - } - }); + if (ts.sort.find(s => s.columnId === colId) !== undefined) { + ts = disableSortByCol(ts, colId); + } return ts; } From abfa4b9c4efc836f6f11b124aaccfea3f72d6a79 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 19:48:54 -0800 Subject: [PATCH 044/120] Moved actions worker --- .../actions_worker/actions/contains_words.py | 0 weave/{ => trace_server}/actions_worker/actions/llm_judge.py | 0 weave/{ => trace_server}/actions_worker/dispatcher.py | 4 ++-- weave/trace_server/clickhouse_trace_server_batched.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename weave/{ => trace_server}/actions_worker/actions/contains_words.py (100%) rename weave/{ => trace_server}/actions_worker/actions/llm_judge.py (100%) rename weave/{ => trace_server}/actions_worker/dispatcher.py (96%) diff --git a/weave/actions_worker/actions/contains_words.py b/weave/trace_server/actions_worker/actions/contains_words.py similarity index 100% rename from weave/actions_worker/actions/contains_words.py rename to weave/trace_server/actions_worker/actions/contains_words.py diff --git a/weave/actions_worker/actions/llm_judge.py b/weave/trace_server/actions_worker/actions/llm_judge.py 
similarity index 100% rename from weave/actions_worker/actions/llm_judge.py rename to weave/trace_server/actions_worker/actions/llm_judge.py diff --git a/weave/actions_worker/dispatcher.py b/weave/trace_server/actions_worker/dispatcher.py similarity index 96% rename from weave/actions_worker/dispatcher.py rename to weave/trace_server/actions_worker/dispatcher.py index de9908f405d..90f06d02085 100644 --- a/weave/actions_worker/dispatcher.py +++ b/weave/trace_server/actions_worker/dispatcher.py @@ -2,10 +2,10 @@ from pydantic import BaseModel -from weave.actions_worker.actions.contains_words import ( +from weave.trace_server.actions_worker.actions.contains_words import ( do_contains_words_action, ) -from weave.actions_worker.actions.llm_judge import do_llm_judge_action +from weave.trace_server.actions_worker.actions.llm_judge import do_llm_judge_action from weave.trace_server.interface.base_object_classes.actions import ( ActionDefinition, ActionSpecType, diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 17d54f3fca2..7d7cac48b70 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -45,11 +45,11 @@ from clickhouse_connect.driver.query import QueryResult from clickhouse_connect.driver.summary import QuerySummary -from weave.actions_worker.dispatcher import execute_batch from weave.trace_server import clickhouse_trace_server_migrator as wf_migrator from weave.trace_server import environment as wf_env from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.actions_worker.dispatcher import execute_batch from weave.trace_server.base_object_class_util import ( process_incoming_object, ) From 9e1ee02808c222bb99951ffed989390b48f60c78 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:07:12 -0800 Subject: [PATCH 045/120] Remove 
window flags --- .../components/FancyPage/useProjectSidebar.ts | 8 +- .../PagePanelComponents/Home/Browse3.tsx | 5 - .../Home/Browse3/pages/CallPage/CallPage.tsx | 22 ++-- .../pages/OpVersionPage/OpVersionPage.tsx | 14 +-- .../Home/Browse3/windowFlags.tsx | 103 ------------------ 5 files changed, 13 insertions(+), 139 deletions(-) delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index c33b6bfe90a..45ded11435d 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -2,10 +2,6 @@ import {IconNames} from '@wandb/weave/components/Icon'; import _ from 'lodash'; import {useMemo} from 'react'; -import { - ENABLE_ONLINE_EVAL_UI, - getFeatureFlag, -} from '../PagePanelComponents/Home/Browse3/windowFlags'; import {FancyPageSidebarItem} from './FancyPageSidebar'; export const useProjectSidebar = ( @@ -35,7 +31,6 @@ export const useProjectSidebar = ( const isNoSidebarItems = !showModelsSidebarItems && !showWeaveSidebarItems; const isBothSidebarItems = showModelsSidebarItems && showWeaveSidebarItems; const isShowAll = isNoSidebarItems || isBothSidebarItems; - const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); return useMemo(() => { const allItems = isLoading @@ -192,7 +187,7 @@ export const useProjectSidebar = ( type: 'button' as const, name: 'Scorers', slug: 'weave/scorers', - isShown: enableOnlineEvalUI && (showWeaveSidebarItems || isShowAll), + isShown: showWeaveSidebarItems || isShowAll, iconName: IconNames.TypeNumberAlt, }, { @@ -255,6 +250,5 @@ export const useProjectSidebar = ( viewingRestricted, isModelsOnly, showWeaveSidebarItems, - enableOnlineEvalUI, ]); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 
f73b62447ce..c2ed51b325f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -103,7 +103,6 @@ import { WFDataModelAutoProvider, } from './Browse3/pages/wfReactInterface/context'; import {useHasTraceServerClientContext} from './Browse3/pages/wfReactInterface/traceServerClientContext'; -import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from './Browse3/windowFlags'; import {useDrawerResize} from './useDrawerResize'; LicenseInfo.setLicenseKey( @@ -993,10 +992,6 @@ const CompareEvaluationsBinding = () => { const ScorersPageBinding = () => { const {entity, project} = useParamsDecoded(); - const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); - if (!enableOnlineEvalUI) { - return null; - } return ; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 5d58634d48c..4675d5b3e27 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -16,7 +16,6 @@ import { } from '../../context'; import {FeedbackGrid} from '../../feedback/FeedbackGrid'; import {NotFoundPanel} from '../../NotFoundPanel'; -import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from '../../windowFlags'; import {isCallChat} from '../ChatView/hooks'; import {isEvaluateOp} from '../common/heuristics'; import {CenteredAnimatedLoader} from '../common/Loader'; @@ -62,7 +61,6 @@ const useCallTabs = (call: CallSchema) => { const codeURI = call.opVersionRef; const {entity, project, callId} = call; const weaveRef = makeRefCall(entity, project, callId); - const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); return [ // Disabling Evaluation tab until it's better for single evaluation ...(false && isEvaluateOp(call.spanName) @@ -131,18 +129,14 @@ 
const useCallTabs = (call: CallSchema) => { ), }, - ...(enableOnlineEvalUI - ? [ - { - label: 'Scores', - content: ( - - - - ), - }, - ] - : []), + { + label: 'Scores', + content: ( + + + + ), + }, { label: 'Use', content: ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx index 1be7e32dfc0..530ac5e995b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx @@ -4,7 +4,6 @@ import {LoadingDots} from '../../../../../LoadingDots'; import {Tailwind} from '../../../../../Tailwind'; import {NotFoundPanel} from '../../NotFoundPanel'; import {OpCodeViewer} from '../../OpCodeViewer'; -import {ENABLE_ONLINE_EVAL_UI, getFeatureFlag} from '../../windowFlags'; import { CallsLink, opNiceName, @@ -72,7 +71,6 @@ const OpVersionPageInner: React.FC<{ // that data available yet. return true; }, []); - const enableOnlineEvalUI = getFeatureFlag(ENABLE_ONLINE_EVAL_UI); return ( ), }, - ...(enableOnlineEvalUI - ? [ - { - label: 'Online Scorers', - content: , - }, - ] - : []), + { + label: 'Online Scorers', + content: , + }, ...(useOpSupported ? [ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx deleted file mode 100644 index 6879c66ead0..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/windowFlags.tsx +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Feature Flag Management Utility - * - * This module provides a simple feature flag management system for toggling beta features. - * It allows developers to specify a list of flags, get the current state of any given flag, - * and enables users to toggle flags via the browser console. - * - * Usage: - * 1. 
Initialize flags at the app's entry point: - * import { initializeFlags } from './windowFlags'; - * initializeFlags(['BETA_FEATURE_1', 'BETA_FEATURE_2']); - * - * 2. Use the hook in React components: - * import { useFeatureFlag } from './windowFlags'; - * const MyComponent = () => { - * const isBetaFeature1Enabled = useFeatureFlag('BETA_FEATURE_1'); - * return isBetaFeature1Enabled ? : null; - * }; - * - * 3. Toggle flags from the browser console: - * window.setFeatureFlag('BETA_FEATURE_1', true); - * - * This system allows for easy management of feature flags and real-time updates - * in React components when flags are toggled. - */ - -import {useEffect, useState} from 'react'; - -// Define the type for the flags -type FeatureFlags = { - [key: string]: boolean; -}; - -// Initialize the flags on the window object -declare global { - interface Window { - featureFlags: FeatureFlags; - setFeatureFlag: (flagName: string, value: boolean) => void; - } -} - -// Initialize the feature flags -const initializeFeatureFlags = (flags: string[]) => { - window.featureFlags = flags.reduce((acc, flag) => { - acc[flag] = false; - return acc; - }, {} as FeatureFlags); - - // Expose a method to set flags - window.setFeatureFlag = (flagName: string, value: boolean) => { - if (flagName in window.featureFlags) { - window.featureFlags[flagName] = value; - // Dispatch a custom event when a flag is changed - window.dispatchEvent( - new CustomEvent('featureFlagChanged', {detail: {flagName, value}}) - ); - } else { - console.warn(`Feature flag "${flagName}" is not defined.`); - } - }; -}; - -// Function to get the current state of a flag -export const getFeatureFlag = (flagName: string): boolean => { - return window.featureFlags?.[flagName] ?? 
false; -}; - -// React hook to use feature flags in components -export const useFeatureFlag = (flagName: string): boolean => { - const [flagValue, setFlagValue] = useState(getFeatureFlag(flagName)); - - useEffect(() => { - const handleFlagChange = (event: CustomEvent) => { - if (event.detail.flagName === flagName) { - setFlagValue(event.detail.value); - } - }; - window.addEventListener( - 'featureFlagChanged', - handleFlagChange as EventListener - ); - - return () => { - window.removeEventListener( - 'featureFlagChanged', - handleFlagChange as EventListener - ); - }; - }, [flagName]); - - return flagValue; -}; - -// Function to initialize the feature flags -export const initializeFlags = (flags: string[]) => { - initializeFeatureFlags(flags); -}; - -export const ENABLE_ONLINE_EVAL_UI = 'ENABLE_ONLINE_EVAL_UI'; -initializeFlags([ENABLE_ONLINE_EVAL_UI]); - -// Set the feature flag to true for testing -window.setFeatureFlag('ENABLE_ONLINE_EVAL_UI', true); From 856ba0b134718faa51aae6ee6da337864a794d0a Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:10:03 -0800 Subject: [PATCH 046/120] Restore accidental changes --- weave-js/src/components/Form/AutoComplete.tsx | 4 ++++ weave-js/src/components/Form/Select.tsx | 16 ++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/weave-js/src/components/Form/AutoComplete.tsx b/weave-js/src/components/Form/AutoComplete.tsx index 6455790fb3c..4fba92aec38 100644 --- a/weave-js/src/components/Form/AutoComplete.tsx +++ b/weave-js/src/components/Form/AutoComplete.tsx @@ -63,6 +63,10 @@ const getStyles = (props: AdditionalProps) => { minHeight: `${HEIGHTS[size]} !important`, }, }, + '& .MuiAutocomplete-popupIndicator': { + borderRadius: '4px', + padding: '4px', + }, '&.MuiAutocomplete-hasPopupIcon .MuiOutlinedInput-root, &.MuiAutocomplete-hasClearIcon .MuiOutlinedInput-root': { paddingRight: props.hasInputValue ? 
'28px' : '0px', // Apply padding only if input exists diff --git a/weave-js/src/components/Form/Select.tsx b/weave-js/src/components/Form/Select.tsx index 30da9ea00fe..2163b6af180 100644 --- a/weave-js/src/components/Form/Select.tsx +++ b/weave-js/src/components/Form/Select.tsx @@ -16,7 +16,8 @@ import { MOON_800, RED_550, TEAL_300, - TEAL_500, + TEAL_350, + TEAL_400, TEAL_600, } from '@wandb/weave/common/css/globals.styles'; import {Icon} from '@wandb/weave/components/Icon'; @@ -204,10 +205,8 @@ const getStyles = < }, control: (baseStyles, state) => { const colorBorderDefault = MOON_250; - const colorBorderHover = hexToRGB(TEAL_500, 0.4); - const colorBorderOpen = errorState - ? hexToRGB(RED_550, 0.64) - : hexToRGB(TEAL_500, 0.64); + const colorBorderHover = TEAL_350; + const colorBorderOpen = errorState ? hexToRGB(RED_550, 0.64) : TEAL_400; const height = HEIGHTS[size]; const minHeight = MIN_HEIGHTS[size] ?? height; const lineHeight = LINE_HEIGHTS[size]; @@ -226,9 +225,10 @@ const getStyles = < ? `0 0 0 2px ${colorBorderOpen}` : `inset 0 0 0 1px ${colorBorderDefault}`, '&:hover': { - boxShadow: state.menuIsOpen - ? `0 0 0 2px ${colorBorderOpen}` - : `0 0 0 2px ${colorBorderHover}`, + boxShadow: + state.menuIsOpen || state.isFocused + ? 
`0 0 0 2px ${colorBorderOpen}` + : `0 0 0 2px ${colorBorderHover}`, }, }; }, From 9999c13c74c332707ef5d30a9f5443dc79c15e94 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:15:13 -0800 Subject: [PATCH 047/120] Fixed merge --- .../Browse3/pages/CallPage/ObjectViewer.tsx | 8 +++---- .../Panel2/PanelTable/tableState.ts | 23 ++++++++++++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index a21e47e3917..1a918d22541 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -73,7 +73,7 @@ type RefValues = Record; // ref URI to value type TruncatedStore = {[key: string]: {values: any; index: number}}; -const RESOVLED_REF_KEY = '_ref'; +const RESOLVED_REF_KEY = '_ref'; export const ARRAY_TRUNCATION_LENGTH = 50; const TRUNCATION_KEY = '__weave_array_truncated__'; @@ -163,13 +163,13 @@ export const ObjectViewer = ({ if (typeof val === 'object' && val !== null) { val = { ...v, - [RESOVLED_REF_KEY]: r, + [RESOLVED_REF_KEY]: r, }; } else { // This makes it so that runs pointing to primitives can still be expanded in the table. 
val = { '': v, - [RESOVLED_REF_KEY]: r, + [RESOLVED_REF_KEY]: r, }; } } @@ -182,7 +182,7 @@ export const ObjectViewer = ({ isWeaveRef(context.value) && refValues[context.value] != null && // Don't expand _ref keys - context.path.tail() !== RESOVLED_REF_KEY + context.path.tail() !== RESOLVED_REF_KEY ) { dirty = true; return refValues[context.value]; diff --git a/weave-js/src/components/Panel2/PanelTable/tableState.ts b/weave-js/src/components/Panel2/PanelTable/tableState.ts index 565fbb4b7f3..4cb8d286798 100644 --- a/weave-js/src/components/Panel2/PanelTable/tableState.ts +++ b/weave-js/src/components/Panel2/PanelTable/tableState.ts @@ -1,5 +1,6 @@ import { allObjPaths, + canSortType, constFunction, ConstNode, constNodeUnsafe, @@ -687,6 +688,16 @@ export async function disableGroupByCol( ) { const colIds = _.isArray(colId) ? colId : [colId]; const groupBy = ts.groupBy; + + // (WB-16067) + // We may try to sort on aggregated columns after ungrouping + // To prevent this, disable sorting on all the columns and re-enable + // after the ungroup + const initiallySortedCols = _.clone(ts.sort); + ts.sort.forEach(sortObj => { + ts = disableSortByCol(ts, sortObj.columnId); + }); + ts = produce(ts, draft => { draft.autoColumns = false; for (const cid of colIds) { @@ -701,9 +712,15 @@ export async function disableGroupByCol( } }); ts = await refreshSelectFunctions(ts, inputArrayNode, weave, stack); - if (ts.sort.find(s => s.columnId === colId) !== undefined) { - ts = disableSortByCol(ts, colId); - } + + initiallySortedCols.forEach(sortObj => { + if ( + sortObj.columnId !== colId && + canSortType(ts.columnSelectFunctions[sortObj.columnId].type) + ) { + ts = enableSortByCol(ts, sortObj.columnId, sortObj.dir === 'asc'); + } + }); return ts; } From 8c51b9336449bd7be4c28622733a4cd43b87d170 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:21:26 -0800 Subject: [PATCH 048/120] Remove online scorers page --- .../OpVersionPage/OpOnlineScorersTab.tsx | 289 
------------------ .../pages/OpVersionPage/OpVersionPage.tsx | 5 - .../traceServerDirectClient.ts | 14 +- 3 files changed, 7 insertions(+), 301 deletions(-) delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx deleted file mode 100644 index 49ad30e69e4..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpOnlineScorersTab.tsx +++ /dev/null @@ -1,289 +0,0 @@ -import {Box} from '@material-ui/core'; -import {Button} from '@wandb/weave/components/Button/Button'; -import {parseRef} from '@wandb/weave/react'; -import React, {FC, useMemo, useState} from 'react'; -import {z} from 'zod'; - -import {SmallRef} from '../../../Browse2/SmallRef'; -import { - ActionDispatchFilterSchema, - ActionDispatchFilterType, -} from '../../collections/actionCollection'; -import {collectionRegistry} from '../../collections/collectionRegistry'; -import { - useCollectionObjects, - useCreateCollectionObject, -} from '../../collections/getCollectionObjects'; -import {DynamicConfigForm} from '../../DynamicConfigForm'; -import {ReusableDrawer} from '../../ReusableDrawer'; -import {StyledDataGrid} from '../../StyledDataGrid'; -import {TraceObjSchema} from '../wfReactInterface/traceServerClientTypes'; -import { - convertISOToDate, - projectIdFromParts, -} from '../wfReactInterface/tsDataModelHooks'; -import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; -import {OpVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface'; - -type OnlineScorerType = TraceObjSchema; - -const useOnlineScorersForOpVersion = ( - opVersion: OpVersionSchema -): {scorers: OnlineScorerType[]; refresh: () => void} => { - const [poorMansRefreshCount, setPoorMansRefreshCount] = 
useState(0); - const req = useMemo(() => { - return { - project_id: projectIdFromParts({ - entity: opVersion.entity, - project: opVersion.project, - }), - filter: { - latest_only: true, - }, - poorMansRefreshCount, - }; - }, [opVersion.entity, opVersion.project, poorMansRefreshCount]); - const scorers = useCollectionObjects('ActionDispatchFilter', req).sort( - (a, b) => { - return ( - convertISOToDate(a.created_at).getTime() - - convertISOToDate(b.created_at).getTime() - ); - } - ); - const refresh = () => { - setPoorMansRefreshCount(poorMansRefreshCount + 1); - }; - return {scorers, refresh}; -}; - -export const OpOnlineScorersTab: React.FC<{ - opVersion: OpVersionSchema; -}> = ({opVersion}) => { - const req = useMemo(() => { - return { - project_id: projectIdFromParts({ - entity: opVersion.entity, - project: opVersion.project, - }), - filter: { - latest_only: true, - }, - }; - }, [opVersion.entity, opVersion.project]); - const availableActions = useCollectionObjects('ConfiguredAction', req); - const [isModalOpen, setIsModalOpen] = useState(false); - const {scorers: onlineScorers, refresh} = - useOnlineScorersForOpVersion(opVersion); - - const handleOpenModal = () => { - setIsModalOpen(true); - }; - - const handleCloseModal = (didSave: boolean) => { - refresh(); - setIsModalOpen(false); - }; - - const columns = [ - { - field: 'name', - headerName: 'Name', - flex: 1, - renderCell: (params: any) => { - return ; - }, - }, - {field: 'sampleRate', headerName: 'Sample Rate', flex: 1}, - { - field: 'configuredActionRef', - headerName: 'Configured Action', - flex: 1, - renderCell: (params: any) => { - return ; - }, - }, - {field: 'disabled', headerName: 'Disabled', flex: 1}, - ]; - - const rows = onlineScorers - .filter(scorer => scorer.val.op_name === opVersion.opId) - .map((scorer, index) => { - const scorerRef = objectVersionKeyToRefUri({ - scheme: 'weave', - weaveKind: 'object', - entity: opVersion.entity, - project: opVersion.project, - objectId: scorer.object_id, - 
versionHash: scorer.digest, - path: '', - }); - return { - id: scorerRef, - name: scorerRef, - createdAt: convertISOToDate(scorer.created_at), - disabled: scorer.val.disabled, - sampleRate: scorer.val.sample_rate, - configuredActionRef: scorer.val.configured_action_ref, - // Map other fields as needed - }; - }); - - const actionRefs = useMemo(() => { - return availableActions.map(action => { - return objectVersionKeyToRefUri({ - scheme: 'weave', - weaveKind: 'object', - entity: opVersion.entity, - project: opVersion.project, - objectId: action.object_id, - versionHash: action.digest, - path: '', - }); - }); - }, [availableActions, opVersion.entity, opVersion.project]); - - const inputSchema = useMemo(() => { - const base = ActionDispatchFilterSchema.merge( - z.object({ - op_name: z.literal(opVersion.opId), - }) - ); - if (actionRefs.length === 0) { - return base; - } - - return base.merge( - z.object({ - configured_action_ref: z.enum( - actionRefs as unknown as [string, ...string[]] - ), - }) - ); - }, [actionRefs, opVersion.opId]); - - return ( - - - - - - - - ); -}; - -interface NewOnlineOpScorerModalProps { - entity: string; - project: string; - collectionDef: { - name: keyof typeof collectionRegistry; - schema: z.Schema; - }; - isOpen: boolean; - onClose: (didSave: boolean) => void; -} - -export const NewOnlineOpScorerModal: FC = ({ - entity, - project, - collectionDef, - isOpen, - onClose, -}) => { - const [config, setConfig] = useState>({}); - - const createCollectionObject = useCreateCollectionObject(collectionDef.name); - - const handleSaveModal = (newAction: Record) => { - const parsedAction = collectionDef.schema.safeParse(newAction); - if (!parsedAction.success) { - console.error( - `Invalid action: ${JSON.stringify(parsedAction.error.errors)}` - ); - return; - } - const opName = parsedAction.data.op_name; - const actionRef = parsedAction.data.configured_action_ref; - const actionName = parseRef(actionRef).artifactName; - let objectId = 
`${opName}-${actionName}`; - // Remove non alphanumeric characters - objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); - createCollectionObject({ - obj: { - project_id: projectIdFromParts({entity, project}), - object_id: objectId, - val: parsedAction.data, - }, - }) - .catch(err => { - console.error(err); - }) - .finally(() => { - setConfig({}); - onClose(true); - }); - }; - - const [isValid, setIsValid] = useState(false); - - return ( - onClose(false)} - onSave={() => handleSaveModal(config)} - saveDisabled={!isValid}> - - - ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx index 530ac5e995b..9e14df7152c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx @@ -20,7 +20,6 @@ import {TabUseOp} from '../TabUseOp'; import {useWFHooks} from '../wfReactInterface/context'; import {opVersionKeyToRefUri} from '../wfReactInterface/utilities'; import {OpVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface'; -import {OpOnlineScorersTab} from './OpOnlineScorersTab'; export const OpVersionPage: React.FC<{ entity: string; @@ -133,10 +132,6 @@ const OpVersionPageInner: React.FC<{ /> ), }, - { - label: 'Online Scorers', - content: , - }, ...(useOpSupported ? 
[ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts index b4bc850edbb..5c269aa378a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts @@ -228,6 +228,13 @@ export class DirectTraceServerClient { return this.makeRequest('/obj/read', req); } + public readBatch(req: TraceRefsReadBatchReq): Promise { + return this.makeRequest( + '/refs/read_batch', + req + ); + } + public objCreate(req: TraceObjCreateReq): Promise { const initialObjectId = req.obj.object_id; const sanitizedObjectId = sanitizeObjectId(initialObjectId); @@ -244,13 +251,6 @@ export class DirectTraceServerClient { ); } - public readBatch(req: TraceRefsReadBatchReq): Promise { - return this.makeRequest( - '/refs/read_batch', - req - ); - } - public tableQuery(req: TraceTableQueryReq): Promise { return this.makeRequest( '/table/query', From ad73416b6c015c2cf4d6d3b47b56b33c5b0f61c7 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:30:10 -0800 Subject: [PATCH 049/120] Changed name from metrics to scorers --- .../PagePanelComponents/Home/Browse3/context.tsx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx index fddcd4fbda4..03e7d462cb0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx @@ -194,7 +194,7 @@ export const browse2Context = { throw new Error('Not implemented'); }, - metricsUIUrl: (entityName: string, projectName: string) => { + scorersUIUrl: (entityName: 
string, projectName: string) => { throw new Error('Not implemented'); }, leaderboardsUIUrl: ( @@ -435,8 +435,8 @@ export const browse3ContextGen = ( )}${metricsPart}`; }, - metricsUIUrl: (entityName: string, projectName: string) => { - return `${projectRoot(entityName, projectName)}/metrics`; + scorersUIUrl: (entityName: string, projectName: string) => { + return `${projectRoot(entityName, projectName)}/scorers`; }, leaderboardsUIUrl: ( @@ -534,7 +534,7 @@ type RouteType = { metrics: Record | null ) => string; - metricsUIUrl: (entityName: string, projectName: string) => string; + scorersUIUrl: (entityName: string, projectName: string) => string; leaderboardsUIUrl: ( entityName: string, @@ -655,8 +655,8 @@ const useMakePeekingRouter = (): RouteType => { baseContext.compareEvaluationsUri(...args) ); }, - metricsUIUrl: (...args: Parameters) => { - return setSearchParam(PEEK_PARAM, baseContext.metricsUIUrl(...args)); + scorersUIUrl: (...args: Parameters) => { + return setSearchParam(PEEK_PARAM, baseContext.scorersUIUrl(...args)); }, leaderboardsUIUrl: ( ...args: Parameters From 10541ad2ecd02a6cac7d3e7b7702c59a9e088a36 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:32:24 -0800 Subject: [PATCH 050/120] Fixed interface --- .../Browse3/pages/wfReactInterface/traceServerClientTypes.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index 21565df0533..bcc68300a4f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -309,8 +309,8 @@ export const fileExtensions = { export type ActionsExecuteBatchReq = { project_id: string; + action_ref: string; 
call_ids: string[]; - configured_action_ref: string; }; export type ActionsExecuteBatchRes = {}; From 0a68915122b4db58190eed79ce5092f73d3269ce Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:33:14 -0800 Subject: [PATCH 051/120] Changed ConfiguredAction to AxctionDefinition --- .../Browse3/collections/actionCollection.ts | 4 ++-- .../Browse3/collections/collectionRegistry.ts | 4 ++-- .../pages/CallPage/CallActionsViewer.tsx | 22 +++++++++---------- .../Home/Browse3/pages/ObjectVersionPage.tsx | 2 +- .../NewBuiltInActionScorerModal.tsx | 8 +++---- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 8 +++---- .../pages/common/TypeVersionCategoryChip.tsx | 2 +- .../pages/wfReactInterface/constants.ts | 2 +- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts index c79c820b2a4..6ed701e87fa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts @@ -34,11 +34,11 @@ export const ActionConfigSchema = z.discriminatedUnion('action_type', [ ]); export type ActionConfigType = z.infer; -export const ConfiguredActionSchema = z.object({ +export const AxctionDefinitionSchema = z.object({ name: z.string(), config: ActionConfigSchema, }); -export type ConfiguredActionType = z.infer; +export type AxctionDefinitionType = z.infer; export const ActionDispatchFilterSchema = z.object({ op_name: z.string(), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts index 8698054e3d4..577416f8b6f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts @@ -1,9 +1,9 @@ import { ActionDispatchFilterSchema, - ConfiguredActionSchema, + AxctionDefinitionSchema, } from './actionCollection'; export const collectionRegistry = { - ConfiguredAction: ConfiguredActionSchema, + AxctionDefinition: AxctionDefinitionSchema, ActionDispatchFilter: ActionDispatchFilterSchema, }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index 520bced46de..846415dfcbc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -7,7 +7,7 @@ import {z} from 'zod'; import {CellValue} from '../../../Browse2/CellValue'; import {NotApplicable} from '../../../Browse2/NotApplicable'; -import {ConfiguredActionType} from '../../collections/actionCollection'; +import {AxctionDefinitionType} from '../../collections/actionCollection'; import {useCollectionObjects} from '../../collections/getCollectionObjects'; import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; @@ -23,7 +23,7 @@ import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; type CallActionRow = { actionRef: string; - actionDef: ConfiguredActionType; + actionDef: AxctionDefinitionType; runCount: number; lastResult?: unknown; lastRanAt?: Date; @@ -89,7 +89,7 @@ export const CallActionsViewer: React.FC<{ weaveRef, }); - const configuredActions = useCollectionObjects('ConfiguredAction', { + const AxctionDefinitions = useCollectionObjects('AxctionDefinition', { project_id: projectIdFromParts({ entity: props.call.entity, project: props.call.project, @@ -130,22 +130,22 @@ export const 
CallActionsViewer: React.FC<{ const allCallActions: CallActionRow[] = useMemo(() => { return ( - configuredActions?.map(configuredAction => { - const configuredActionRefUri = objectVersionKeyToRefUri({ + AxctionDefinitions?.map(AxctionDefinition => { + const AxctionDefinitionRefUri = objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', entity: props.call.entity, project: props.call.project, - objectId: configuredAction.object_id, - versionHash: configuredAction.digest, + objectId: AxctionDefinition.object_id, + versionHash: AxctionDefinition.digest, path: '', }); - const feedbacks = getFeedbackForAction(configuredActionRefUri); + const feedbacks = getFeedbackForAction(AxctionDefinitionRefUri); const selectedFeedback = feedbacks.length > 0 ? feedbacks[0] : undefined; return { - actionRef: configuredActionRefUri, - actionDef: configuredAction.val, + actionRef: AxctionDefinitionRefUri, + actionDef: AxctionDefinition.val, runCount: feedbacks.length, lastRanAt: selectedFeedback ? convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') @@ -157,7 +157,7 @@ export const CallActionsViewer: React.FC<{ }) ?? 
[] ); }, [ - configuredActions, + AxctionDefinitions, getFeedbackForAction, props.call.entity, props.call.project, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 4b95e9b2ca0..fca2cac7aa0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -54,7 +54,7 @@ const OBJECT_ICONS: Record = { Evaluation: 'baseline-alt', Leaderboard: 'benchmark-square', Scorer: 'type-number-alt', - ConfiguredAction: 'rocket-launch', + AxctionDefinition: 'rocket-launch', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index f332e36d77c..23736bd3504 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -10,8 +10,8 @@ import React, {FC, useEffect, useState} from 'react'; import {z} from 'zod'; import { - ConfiguredActionSchema, - ConfiguredActionType, + AxctionDefinitionSchema, + AxctionDefinitionType, ConfiguredLlmJudgeActionSchema, ConfiguredWordCountActionSchema, } from '../../collections/actionCollection'; @@ -73,7 +73,7 @@ const knownBuiltinActions = [ interface NewBuiltInActionScorerModalProps { open: boolean; onClose: () => void; - onSave: (newAction: ConfiguredActionType) => void; + onSave: (newAction: AxctionDefinitionType) => void; initialTemplate: string; } @@ -98,7 +98,7 @@ export const NewBuiltInActionScorerModal: FC< }, [initialTemplate]); const 
handleSave = () => { - const newAction = ConfiguredActionSchema.parse({ + const newAction = AxctionDefinitionSchema.parse({ name, config: knownBuiltinActions[selectedActionIndex].friendly.convert( config as any diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 585d0c7e47f..35641ab62cb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -5,7 +5,7 @@ import MenuItem from '@mui/material/MenuItem'; import {Button} from '@wandb/weave/components/Button/Button'; import React, {FC, useState} from 'react'; -import {ConfiguredActionType} from '../../collections/actionCollection'; +import {AxctionDefinitionType} from '../../collections/actionCollection'; import {useCreateCollectionObject} from '../../collections/getCollectionObjects'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; @@ -81,7 +81,7 @@ const OnlineScorersTab: React.FC<{ }> = ({entity, project}) => { const [isModalOpen, setIsModalOpen] = useState(false); const [selectedTemplate, setSelectedTemplate] = useState(''); - const createCollectionObject = useCreateCollectionObject('ConfiguredAction'); + const createCollectionObject = useCreateCollectionObject('AxctionDefinition'); const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); const [anchorEl, setAnchorEl] = React.useState(null); @@ -111,7 +111,7 @@ const OnlineScorersTab: React.FC<{ setSelectedTemplate(''); }; - const handleSaveModal = (newAction: ConfiguredActionType) => { + const handleSaveModal = (newAction: AxctionDefinitionType) => { let objectId = newAction.name; // Remove non alphanumeric characters objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); @@ 
-181,7 +181,7 @@ const OnlineScorersTab: React.FC<{ entity={entity} project={project} initialFilter={{ - baseObjectClass: 'ConfiguredAction', + baseObjectClass: 'AxctionDefinition', }} /> = { Evaluation: 'cactus', Leaderboard: 'gold', Scorer: 'purple', - ConfiguredAction: 'sienna', + AxctionDefinition: 'sienna', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index f9a9b5cda8d..31d93be42ba 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -26,5 +26,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Evaluation', 'Leaderboard', 'Scorer', - 'ConfiguredAction', + 'AxctionDefinition', ] as const; From 0b2152dda9abc4d3d7ec26851f730fcc1413cfac Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:33:54 -0800 Subject: [PATCH 052/120] Changed ConfiguredAction to AxctionDefinition --- .../Browse3/collections/actionCollection.ts | 4 ++-- .../Browse3/collections/collectionRegistry.ts | 4 ++-- .../pages/CallPage/CallActionsViewer.tsx | 22 +++++++++---------- .../Home/Browse3/pages/ObjectVersionPage.tsx | 2 +- .../NewBuiltInActionScorerModal.tsx | 8 +++---- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 8 +++---- .../pages/common/TypeVersionCategoryChip.tsx | 2 +- .../pages/wfReactInterface/constants.ts | 2 +- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts index 6ed701e87fa..21f129d7c83 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts @@ -34,11 +34,11 @@ export const ActionConfigSchema = z.discriminatedUnion('action_type', [ ]); export type ActionConfigType = z.infer; -export const AxctionDefinitionSchema = z.object({ +export const ActionDefinitionSchema = z.object({ name: z.string(), config: ActionConfigSchema, }); -export type AxctionDefinitionType = z.infer; +export type ActionDefinitionType = z.infer; export const ActionDispatchFilterSchema = z.object({ op_name: z.string(), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts index 577416f8b6f..ad5cbf4e070 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts @@ -1,9 +1,9 @@ import { ActionDispatchFilterSchema, - AxctionDefinitionSchema, + ActionDefinitionSchema, } from './actionCollection'; export const collectionRegistry = { - AxctionDefinition: AxctionDefinitionSchema, + ActionDefinition: ActionDefinitionSchema, ActionDispatchFilter: ActionDispatchFilterSchema, }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index 846415dfcbc..e2bc6a2f812 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -7,7 +7,7 @@ import {z} from 'zod'; import {CellValue} from '../../../Browse2/CellValue'; import {NotApplicable} from '../../../Browse2/NotApplicable'; -import {AxctionDefinitionType} from '../../collections/actionCollection'; +import {ActionDefinitionType} from 
'../../collections/actionCollection'; import {useCollectionObjects} from '../../collections/getCollectionObjects'; import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; @@ -23,7 +23,7 @@ import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; type CallActionRow = { actionRef: string; - actionDef: AxctionDefinitionType; + actionDef: ActionDefinitionType; runCount: number; lastResult?: unknown; lastRanAt?: Date; @@ -89,7 +89,7 @@ export const CallActionsViewer: React.FC<{ weaveRef, }); - const AxctionDefinitions = useCollectionObjects('AxctionDefinition', { + const ActionDefinitions = useCollectionObjects('ActionDefinition', { project_id: projectIdFromParts({ entity: props.call.entity, project: props.call.project, @@ -130,22 +130,22 @@ export const CallActionsViewer: React.FC<{ const allCallActions: CallActionRow[] = useMemo(() => { return ( - AxctionDefinitions?.map(AxctionDefinition => { - const AxctionDefinitionRefUri = objectVersionKeyToRefUri({ + ActionDefinitions?.map(ActionDefinition => { + const ActionDefinitionRefUri = objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', entity: props.call.entity, project: props.call.project, - objectId: AxctionDefinition.object_id, - versionHash: AxctionDefinition.digest, + objectId: ActionDefinition.object_id, + versionHash: ActionDefinition.digest, path: '', }); - const feedbacks = getFeedbackForAction(AxctionDefinitionRefUri); + const feedbacks = getFeedbackForAction(ActionDefinitionRefUri); const selectedFeedback = feedbacks.length > 0 ? feedbacks[0] : undefined; return { - actionRef: AxctionDefinitionRefUri, - actionDef: AxctionDefinition.val, + actionRef: ActionDefinitionRefUri, + actionDef: ActionDefinition.val, runCount: feedbacks.length, lastRanAt: selectedFeedback ? 
convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') @@ -157,7 +157,7 @@ export const CallActionsViewer: React.FC<{ }) ?? [] ); }, [ - AxctionDefinitions, + ActionDefinitions, getFeedbackForAction, props.call.entity, props.call.project, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index fca2cac7aa0..f60cf7c7007 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -54,7 +54,7 @@ const OBJECT_ICONS: Record = { Evaluation: 'baseline-alt', Leaderboard: 'benchmark-square', Scorer: 'type-number-alt', - AxctionDefinition: 'rocket-launch', + ActionDefinition: 'rocket-launch', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index 23736bd3504..4a2234c8fb9 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -10,8 +10,8 @@ import React, {FC, useEffect, useState} from 'react'; import {z} from 'zod'; import { - AxctionDefinitionSchema, - AxctionDefinitionType, + ActionDefinitionSchema, + ActionDefinitionType, ConfiguredLlmJudgeActionSchema, ConfiguredWordCountActionSchema, } from '../../collections/actionCollection'; @@ -73,7 +73,7 @@ const knownBuiltinActions = [ interface NewBuiltInActionScorerModalProps { open: boolean; onClose: () => void; - onSave: (newAction: AxctionDefinitionType) => void; + onSave: (newAction: 
ActionDefinitionType) => void; initialTemplate: string; } @@ -98,7 +98,7 @@ export const NewBuiltInActionScorerModal: FC< }, [initialTemplate]); const handleSave = () => { - const newAction = AxctionDefinitionSchema.parse({ + const newAction = ActionDefinitionSchema.parse({ name, config: knownBuiltinActions[selectedActionIndex].friendly.convert( config as any diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 35641ab62cb..25fa6cc4f8c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -5,7 +5,7 @@ import MenuItem from '@mui/material/MenuItem'; import {Button} from '@wandb/weave/components/Button/Button'; import React, {FC, useState} from 'react'; -import {AxctionDefinitionType} from '../../collections/actionCollection'; +import {ActionDefinitionType} from '../../collections/actionCollection'; import {useCreateCollectionObject} from '../../collections/getCollectionObjects'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; @@ -81,7 +81,7 @@ const OnlineScorersTab: React.FC<{ }> = ({entity, project}) => { const [isModalOpen, setIsModalOpen] = useState(false); const [selectedTemplate, setSelectedTemplate] = useState(''); - const createCollectionObject = useCreateCollectionObject('AxctionDefinition'); + const createCollectionObject = useCreateCollectionObject('ActionDefinition'); const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); const [anchorEl, setAnchorEl] = React.useState(null); @@ -111,7 +111,7 @@ const OnlineScorersTab: React.FC<{ setSelectedTemplate(''); }; - const handleSaveModal = (newAction: AxctionDefinitionType) => { + const handleSaveModal = 
(newAction: ActionDefinitionType) => { let objectId = newAction.name; // Remove non alphanumeric characters objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); @@ -181,7 +181,7 @@ const OnlineScorersTab: React.FC<{ entity={entity} project={project} initialFilter={{ - baseObjectClass: 'AxctionDefinition', + baseObjectClass: 'ActionDefinition', }} /> = { Evaluation: 'cactus', Leaderboard: 'gold', Scorer: 'purple', - AxctionDefinition: 'sienna', + ActionDefinition: 'sienna', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index 31d93be42ba..bf54110c369 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -26,5 +26,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Evaluation', 'Leaderboard', 'Scorer', - 'AxctionDefinition', + 'ActionDefinition', ] as const; From 7514cd849b2c840b3d9a41d6eada93b09de2cc07 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 20:35:01 -0800 Subject: [PATCH 053/120] Lint --- .../Home/Browse3/collections/collectionRegistry.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts index ad5cbf4e070..b38fe7d20d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts @@ -1,6 +1,6 @@ import { - ActionDispatchFilterSchema, ActionDefinitionSchema, + ActionDispatchFilterSchema, } from './actionCollection'; export const collectionRegistry = { From 
6279d441b278502c7a129c1f0441cf032ffe2d12 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 21:39:22 -0800 Subject: [PATCH 054/120] Fixed zod --- .../generatedBaseObjectClasses.zod.ts | 35 ++++++ .../generated_base_object_class_schemas.json | 118 +++++++++++++++++- 2 files changed, 152 insertions(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts index 1acf71ad314..10ba7842599 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts @@ -1,5 +1,32 @@ import * as z from 'zod'; +// BEGINNING OF CUSTOM CODE ///// +// Sadly, the json-schema to zod converter doesn't support discriminator +// so we have to define the schemas manually. If you run the generator +// make sure to review the changes to this section. 
+export const LlmJudgeActionSpecSchema = z.object({ + action_type: z.enum(['llm_judge']), + model: z.enum(['gpt-4o', 'gpt-4o-mini']), + prompt: z.string(), + response_schema: z.record(z.string(), z.any()), +}); +export type LlmJudgeActionSpec = z.infer; + +export const ContainsWordsActionSpecSchema = z.object({ + action_type: z.enum(['contains_words']), + target_words: z.array(z.string()), +}); +export type ContainsWordsActionSpec = z.infer< + typeof ContainsWordsActionSpecSchema +>; + +export const SpecSchema = z.discriminatedUnion('action_type', [ + LlmJudgeActionSpecSchema, + ContainsWordsActionSpecSchema, +]); +export type Spec = z.infer; +// END OF CUSTOM CODE ///// + export const LeaderboardColumnSchema = z.object({ evaluation_object_ref: z.string(), scorer_name: z.string(), @@ -24,6 +51,13 @@ export type TestOnlyNestedBaseObject = z.infer< typeof TestOnlyNestedBaseObjectSchema >; +export const ActionDefinitionSchema = z.object({ + description: z.union([z.null(), z.string()]).optional(), + name: z.union([z.null(), z.string()]).optional(), + spec: SpecSchema, +}); +export type ActionDefinition = z.infer; + export const LeaderboardSchema = z.object({ columns: z.array(LeaderboardColumnSchema), description: z.union([z.null(), z.string()]).optional(), @@ -41,6 +75,7 @@ export const TestOnlyExampleSchema = z.object({ export type TestOnlyExample = z.infer; export const baseObjectClassRegistry = { + ActionDefinition: ActionDefinitionSchema, Leaderboard: LeaderboardSchema, TestOnlyExample: TestOnlyExampleSchema, TestOnlyNestedBaseObject: TestOnlyNestedBaseObjectSchema, diff --git a/weave/trace_server/interface/base_object_classes/generated/generated_base_object_class_schemas.json b/weave/trace_server/interface/base_object_classes/generated/generated_base_object_class_schemas.json index 2207d2c1f3c..15b3660cf2b 100644 --- a/weave/trace_server/interface/base_object_classes/generated/generated_base_object_class_schemas.json +++ 
b/weave/trace_server/interface/base_object_classes/generated/generated_base_object_class_schemas.json @@ -1,5 +1,81 @@ { "$defs": { + "ActionDefinition": { + "properties": { + "name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Description" + }, + "spec": { + "discriminator": { + "mapping": { + "contains_words": "#/$defs/ContainsWordsActionSpec", + "llm_judge": "#/$defs/LlmJudgeActionSpec" + }, + "propertyName": "action_type" + }, + "oneOf": [ + { + "$ref": "#/$defs/LlmJudgeActionSpec" + }, + { + "$ref": "#/$defs/ContainsWordsActionSpec" + } + ], + "title": "Spec" + } + }, + "required": [ + "spec" + ], + "title": "ActionDefinition", + "type": "object" + }, + "ContainsWordsActionSpec": { + "properties": { + "action_type": { + "const": "contains_words", + "default": "contains_words", + "enum": [ + "contains_words" + ], + "title": "Action Type", + "type": "string" + }, + "target_words": { + "items": { + "type": "string" + }, + "title": "Target Words", + "type": "array" + } + }, + "required": [ + "target_words" + ], + "title": "ContainsWordsActionSpec", + "type": "object" + }, "Leaderboard": { "properties": { "name": { @@ -75,6 +151,42 @@ "title": "LeaderboardColumn", "type": "object" }, + "LlmJudgeActionSpec": { + "properties": { + "action_type": { + "const": "llm_judge", + "default": "llm_judge", + "enum": [ + "llm_judge" + ], + "title": "Action Type", + "type": "string" + }, + "model": { + "enum": [ + "gpt-4o", + "gpt-4o-mini" + ], + "title": "Model", + "type": "string" + }, + "prompt": { + "title": "Prompt", + "type": "string" + }, + "response_schema": { + "title": "Response Schema", + "type": "object" + } + }, + "required": [ + "model", + "prompt", + "response_schema" + ], + "title": "LlmJudgeActionSpec", + "type": "object" + }, "TestOnlyExample": { "properties": { "name": { 
@@ -181,12 +293,16 @@ }, "Leaderboard": { "$ref": "#/$defs/Leaderboard" + }, + "ActionDefinition": { + "$ref": "#/$defs/ActionDefinition" } }, "required": [ "TestOnlyExample", "TestOnlyNestedBaseObject", - "Leaderboard" + "Leaderboard", + "ActionDefinition" ], "title": "CompositeBaseObject", "type": "object" From 9aec32a78bcb94a42bf39e22e08cb77fc87a1edb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 21:40:18 -0800 Subject: [PATCH 055/120] Fixed discriminator --- weave/trace_server/interface/base_object_classes/actions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/weave/trace_server/interface/base_object_classes/actions.py b/weave/trace_server/interface/base_object_classes/actions.py index 8ef94402945..4262fe5d7ce 100644 --- a/weave/trace_server/interface/base_object_classes/actions.py +++ b/weave/trace_server/interface/base_object_classes/actions.py @@ -1,6 +1,6 @@ from typing import Any, Literal, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from weave.trace_server.interface.base_object_classes import base_object_def @@ -30,4 +30,4 @@ class ContainsWordsActionSpec(BaseModel): class ActionDefinition(base_object_def.BaseObject): # Pyright doesn't like this override # name: str - spec: ActionSpecType + spec: ActionSpecType = Field(..., discriminator="action_type") From a45026bddb13cf7e1e463bd860e63d63d9fc72a1 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 22:02:28 -0800 Subject: [PATCH 056/120] Migrated to new data model --- .../Browse3/collections/actionCollection.ts | 50 ++++---- .../Browse3/collections/collectionRegistry.ts | 9 -- .../collections/getCollectionObjects.tsx | 118 ------------------ .../pages/CallPage/CallActionsViewer.tsx | 39 +++--- .../NewBuiltInActionScorerModal.tsx | 41 +++--- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 12 +- 6 files changed, 73 insertions(+), 196 deletions(-) delete mode 100644 
weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts index 21f129d7c83..6a288aa0ca0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts @@ -21,31 +21,31 @@ export type ConfiguredLlmJudgeActionType = z.infer< typeof ConfiguredLlmJudgeActionSchema >; -export const ConfiguredWordCountActionSchema = z.object({ - action_type: z.literal('wordcount'), -}); -export type ConfiguredWordCountActionType = z.infer< - typeof ConfiguredWordCountActionSchema ->; +// export const ConfiguredWordCountActionSchema = z.object({ +// action_type: z.literal('wordcount'), +// }); +// export type ConfiguredWordCountActionType = z.infer< +// typeof ConfiguredWordCountActionSchema +// >; -export const ActionConfigSchema = z.discriminatedUnion('action_type', [ - ConfiguredLlmJudgeActionSchema, - ConfiguredWordCountActionSchema, -]); -export type ActionConfigType = z.infer; +// export const ActionConfigSchema = z.discriminatedUnion('action_type', [ +// ConfiguredLlmJudgeActionSchema, +// ConfiguredWordCountActionSchema, +// ]); +// export type ActionConfigType = z.infer; -export const ActionDefinitionSchema = z.object({ - name: z.string(), - config: ActionConfigSchema, -}); -export type ActionDefinitionType = z.infer; +// export const ActionDefinitionSchema = z.object({ +// name: z.string(), +// config: ActionConfigSchema, +// }); +// export type ActionDefinitionType = z.infer; -export const ActionDispatchFilterSchema = z.object({ - op_name: z.string(), - sample_rate: z.number().min(0).max(1).default(1), - 
configured_action_ref: z.string(), - disabled: z.boolean().optional(), -}); -export type ActionDispatchFilterType = z.infer< - typeof ActionDispatchFilterSchema ->; +// export const ActionDispatchFilterSchema = z.object({ +// op_name: z.string(), +// sample_rate: z.number().min(0).max(1).default(1), +// configured_action_ref: z.string(), +// disabled: z.boolean().optional(), +// }); +// export type ActionDispatchFilterType = z.infer< +// typeof ActionDispatchFilterSchema +// >; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts deleted file mode 100644 index b38fe7d20d0..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/collectionRegistry.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { - ActionDefinitionSchema, - ActionDispatchFilterSchema, -} from './actionCollection'; - -export const collectionRegistry = { - ActionDefinition: ActionDefinitionSchema, - ActionDispatchFilter: ActionDispatchFilterSchema, -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx deleted file mode 100644 index 1f3707aff05..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/getCollectionObjects.tsx +++ /dev/null @@ -1,118 +0,0 @@ -import {useDeepMemo} from '@wandb/weave/hookUtils'; -import {useEffect, useState} from 'react'; -import {z} from 'zod'; - -import {TraceServerClient} from '../pages/wfReactInterface/traceServerClient'; -import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; -import { - TraceObjCreateReq, - TraceObjQueryReq, - TraceObjSchema, -} from '../pages/wfReactInterface/traceServerClientTypes'; -import {collectionRegistry} from './collectionRegistry'; - -export const 
useCollectionObjects = < - C extends keyof typeof collectionRegistry, - T extends z.infer<(typeof collectionRegistry)[C]> ->( - collectionName: C, - req: TraceObjQueryReq -) => { - const [objects, setObjects] = useState>>([]); - const getTsClient = useGetTraceServerClientContext(); - const client = getTsClient(); - const deepReq = useDeepMemo(req); - - useEffect(() => { - let isMounted = true; - getCollectionObjects(client, collectionName, deepReq).then( - collectionObjects => { - if (isMounted) { - setObjects(collectionObjects as Array>); - } - } - ); - return () => { - isMounted = false; - }; - }, [client, collectionName, deepReq]); - - return objects; -}; - -const getCollectionObjects = async < - C extends keyof typeof collectionRegistry, - T extends z.infer<(typeof collectionRegistry)[C]> ->( - client: TraceServerClient, - collectionName: C, - req: TraceObjQueryReq -): Promise>> => { - const knownCollection = collectionRegistry[collectionName]; - if (!knownCollection) { - console.warn(`Unknown collection: ${collectionName}`); - return []; - } - - const reqWithCollection: TraceObjQueryReq = { - ...req, - filter: {...req.filter, base_object_classes: [collectionName]}, - }; - - const objectPromise = client.objsQuery(reqWithCollection); - - const objects = await objectPromise; - - return objects.objs - .map(obj => ({obj, parsed: knownCollection.safeParse(obj.val)})) - .filter(({parsed}) => parsed.success) - .map(({obj, parsed}) => ({...obj, val: parsed.data!})) as Array< - TraceObjSchema - >; -}; - -export const useCreateCollectionObject = < - C extends keyof typeof collectionRegistry, - T extends z.infer<(typeof collectionRegistry)[C]> ->( - collectionName: C -) => { - const getTsClient = useGetTraceServerClientContext(); - const client = getTsClient(); - return (req: TraceObjCreateReq) => - createCollectionObject(client, collectionName, req); -}; - -const createCollectionObject = async < - C extends keyof typeof collectionRegistry, - T extends z.infer<(typeof 
collectionRegistry)[C]> ->( - client: TraceServerClient, - collectionName: C, - req: TraceObjCreateReq -) => { - const knownCollection = collectionRegistry[collectionName]; - if (!knownCollection) { - throw new Error(`Unknown collection: ${collectionName}`); - } - - const verifiedObject = knownCollection.safeParse(req.obj.val); - - if (!verifiedObject.success) { - throw new Error( - `Invalid object: ${JSON.stringify(verifiedObject.error.errors)}` - ); - } - - const reqWithCollection: TraceObjCreateReq = { - ...req, - obj: { - ...req.obj, - val: {...req.obj.val, _bases: [collectionName, 'BaseModel']}, - }, - }; - - const createPromse = client.objCreate(reqWithCollection); - - return createPromse; -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index e2bc6a2f812..18feb056667 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -7,11 +7,12 @@ import {z} from 'zod'; import {CellValue} from '../../../Browse2/CellValue'; import {NotApplicable} from '../../../Browse2/NotApplicable'; -import {ActionDefinitionType} from '../../collections/actionCollection'; -import {useCollectionObjects} from '../../collections/getCollectionObjects'; +// import {ActionDefinitionType} from '../../collections/actionCollection'; import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component +import {useBaseObjectInstances} from '../wfReactInterface/baseObjectClassQuery'; import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; import {useWFHooks} from '../wfReactInterface/context'; +import {ActionDefinition} from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {useGetTraceServerClientContext} from 
'../wfReactInterface/traceServerClientContext'; import {Feedback} from '../wfReactInterface/traceServerClientTypes'; import { @@ -23,7 +24,7 @@ import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; type CallActionRow = { actionRef: string; - actionDef: ActionDefinitionType; + actionDef: ActionDefinition; runCount: number; lastResult?: unknown; lastRanAt?: Date; @@ -89,13 +90,15 @@ export const CallActionsViewer: React.FC<{ weaveRef, }); - const ActionDefinitions = useCollectionObjects('ActionDefinition', { - project_id: projectIdFromParts({ - entity: props.call.entity, - project: props.call.project, - }), - filter: {latest_only: true}, - }).sort((a, b) => a.val.name.localeCompare(b.val.name)); + const actionDefinitions = ( + useBaseObjectInstances('ActionDefinition', { + project_id: projectIdFromParts({ + entity: props.call.entity, + project: props.call.project, + }), + filter: {latest_only: true}, + }).result ?? [] + ).sort((a, b) => (a.val.name ?? '').localeCompare(b.val.name ?? '')); const verifiedActionFeedbacks: Array<{ data: MachineScoreFeedbackPayloadType; feedbackRaw: Feedback; @@ -130,22 +133,22 @@ export const CallActionsViewer: React.FC<{ const allCallActions: CallActionRow[] = useMemo(() => { return ( - ActionDefinitions?.map(ActionDefinition => { - const ActionDefinitionRefUri = objectVersionKeyToRefUri({ + actionDefinitions.map(actionDefinition => { + const actionDefinitionRefUri = objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', entity: props.call.entity, project: props.call.project, - objectId: ActionDefinition.object_id, - versionHash: ActionDefinition.digest, + objectId: actionDefinition.object_id, + versionHash: actionDefinition.digest, path: '', }); - const feedbacks = getFeedbackForAction(ActionDefinitionRefUri); + const feedbacks = getFeedbackForAction(actionDefinitionRefUri); const selectedFeedback = feedbacks.length > 0 ? 
feedbacks[0] : undefined; return { - actionRef: ActionDefinitionRefUri, - actionDef: ActionDefinition.val, + actionRef: actionDefinitionRefUri, + actionDef: actionDefinition.val, runCount: feedbacks.length, lastRanAt: selectedFeedback ? convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') @@ -157,7 +160,7 @@ export const CallActionsViewer: React.FC<{ }) ?? [] ); }, [ - ActionDefinitions, + actionDefinitions, getFeedbackForAction, props.call.entity, props.call.project, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index 4a2234c8fb9..1813610798c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -9,14 +9,13 @@ import _ from 'lodash'; import React, {FC, useEffect, useState} from 'react'; import {z} from 'zod'; -import { - ActionDefinitionSchema, - ActionDefinitionType, - ConfiguredLlmJudgeActionSchema, - ConfiguredWordCountActionSchema, -} from '../../collections/actionCollection'; +import {ConfiguredLlmJudgeActionSchema} from '../../collections/actionCollection'; import {DynamicConfigForm} from '../../DynamicConfigForm'; import {ReusableDrawer} from '../../ReusableDrawer'; +import { + ActionDefinition, + ActionDefinitionSchema, +} from '../wfReactInterface/generatedBaseObjectClasses.zod'; import { actionTemplates, ConfiguredLlmJudgeActionFriendlySchema, @@ -54,26 +53,26 @@ const knownBuiltinActions = [ }, }, }, - { - name: 'Word Count', - actionSchema: ConfiguredWordCountActionSchema, - friendly: { - schema: z.object({}), - convert: ( - data: z.infer - ): z.infer => { - return { - action_type: 'wordcount', - }; - }, - }, - }, + // { + // name: 'Word Count', + // actionSchema: 
ConfiguredWordCountActionSchema, + // friendly: { + // schema: z.object({}), + // convert: ( + // data: z.infer + // ): z.infer => { + // return { + // action_type: 'wordcount', + // }; + // }, + // }, + // }, ]; interface NewBuiltInActionScorerModalProps { open: boolean; onClose: () => void; - onSave: (newAction: ActionDefinitionType) => void; + onSave: (newAction: ActionDefinition) => void; initialTemplate: string; } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 25fa6cc4f8c..a12167bc865 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -5,10 +5,10 @@ import MenuItem from '@mui/material/MenuItem'; import {Button} from '@wandb/weave/components/Button/Button'; import React, {FC, useState} from 'react'; -import {ActionDefinitionType} from '../../collections/actionCollection'; -import {useCreateCollectionObject} from '../../collections/getCollectionObjects'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; +import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +import {ActionDefinition} from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {actionTemplates} from './actionTemplates'; import {NewBuiltInActionScorerModal} from './NewBuiltInActionScorerModal'; @@ -81,7 +81,8 @@ const OnlineScorersTab: React.FC<{ }> = ({entity, project}) => { const [isModalOpen, setIsModalOpen] = useState(false); const [selectedTemplate, setSelectedTemplate] = useState(''); - const createCollectionObject = useCreateCollectionObject('ActionDefinition'); + const 
createCollectionObject = + useCreateBaseObjectInstance('ActionDefinition'); const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); const [anchorEl, setAnchorEl] = React.useState(null); @@ -111,10 +112,11 @@ const OnlineScorersTab: React.FC<{ setSelectedTemplate(''); }; - const handleSaveModal = (newAction: ActionDefinitionType) => { + const handleSaveModal = (newAction: ActionDefinition) => { let objectId = newAction.name; // Remove non alphanumeric characters - objectId = objectId.replace(/[^a-zA-Z0-9]/g, '-'); + // TODO: reconcile this null-name issue + objectId = objectId?.replace(/[^a-zA-Z0-9]/g, '-') ?? ''; createCollectionObject({ obj: { project_id: projectIdFromParts({entity, project}), From 6919cf29895979312df4a6a95b9ca847d5f97a68 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 22:25:16 -0800 Subject: [PATCH 057/120] Fixed a few things --- .../components/FancyPage/useProjectSidebar.ts | 21 ++++++----- .../Browse3/collections/actionCollection.ts | 2 +- .../NewBuiltInActionScorerModal.tsx | 12 +++---- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 35 +++++-------------- .../pages/ScorersPage/actionTemplates.tsx | 6 ++-- 5 files changed, 31 insertions(+), 45 deletions(-) diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index 45ded11435d..680bd43efba 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -157,6 +157,13 @@ export const useProjectSidebar = ( isShown: isWeaveOnly, iconName: IconNames.BenchmarkSquare, }, + { + type: 'button' as const, + name: 'Scorers', + slug: 'weave/scorers', + isShown: isWeaveOnly, + iconName: IconNames.TypeNumberAlt, + }, { type: 'divider' as const, key: 'dividerWithinWeave-2', @@ -183,13 +190,6 @@ export const useProjectSidebar = ( isShown: showWeaveSidebarItems || isShowAll, iconName: IconNames.Table, }, - { - type: 'button' as const, - name: 
'Scorers', - slug: 'weave/scorers', - isShown: showWeaveSidebarItems || isShowAll, - iconName: IconNames.TypeNumberAlt, - }, { type: 'divider' as const, key: 'dividerWithinWeave-3', @@ -218,7 +218,12 @@ export const useProjectSidebar = ( key: 'moreWeave', isShown: isShowAll, // iconName: IconNames.OverflowHorizontal, - menu: ['weave/leaderboards', 'weave/operations', 'weave/objects'], + menu: [ + 'weave/leaderboards', + 'weave/scorers', + 'weave/operations', + 'weave/objects', + ], }, ]; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts index 6a288aa0ca0..e74a2df6f17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts @@ -12,7 +12,7 @@ export const ConfiguredLlmJudgeActionSchema = z.object({ action_type: z.literal('llm_judge'), model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), prompt: z.string(), - response_format: z.discriminatedUnion('type', [ + response_schema: z.discriminatedUnion('type', [ SimpleJsonResponseFormat, ObjectJsonResponseFormat, ]), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index 1813610798c..600d564b3ec 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -32,13 +32,13 @@ const knownBuiltinActions = [ ): z.infer => { let responseFormat: z.infer< typeof ConfiguredLlmJudgeActionSchema - >['response_format']; - if (data.response_format.type === 'simple') { - responseFormat = {type: 
data.response_format.schema}; + >['response_schema']; + if (data.response_schema.type === 'simple') { + responseFormat = {type: data.response_schema.schema}; } else { responseFormat = { type: 'object', - properties: _.mapValues(data.response_format.schema, value => ({ + properties: _.mapValues(data.response_schema.schema, value => ({ type: value as 'boolean' | 'number' | 'string', })), additionalProperties: false, @@ -48,7 +48,7 @@ const knownBuiltinActions = [ action_type: 'llm_judge', model: data.model, prompt: data.prompt, - response_format: responseFormat, + response_schema: responseFormat, }; }, }, @@ -99,7 +99,7 @@ export const NewBuiltInActionScorerModal: FC< const handleSave = () => { const newAction = ActionDefinitionSchema.parse({ name, - config: knownBuiltinActions[selectedActionIndex].friendly.convert( + spec: knownBuiltinActions[selectedActionIndex].friendly.convert( config as any ), }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index a12167bc865..2372c7403a0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -1,5 +1,4 @@ import {Box} from '@material-ui/core'; -import {Alert} from '@mui/material'; import Menu from '@mui/material/Menu'; import MenuItem from '@mui/material/MenuItem'; import {Button} from '@wandb/weave/components/Button/Button'; @@ -20,16 +19,18 @@ export const ScorersPage: React.FC<{ return ( , }, - { - label: 'Human Review', - content: , - }, + // This is a placeholder for Griffin's annotation column manager section + // { + // label: 'Human Review', + // content: , + // }, { label: 'Code Scorers', content: , @@ -55,26 +56,6 @@ const CodeScorersTab: React.FC<{ ); }; -const HumanScorersTab: React.FC<{ - entity: string; - project: 
string; -}> = ({entity, project}) => { - return ( - - Human Review coming soon - - ); - // return ( - // - // ); -}; - const OnlineScorersTab: React.FC<{ entity: string; project: string; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx index 51f7ca32bcd..f0f0d515b0d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.tsx @@ -19,7 +19,7 @@ const ResponseFormatSchema = z.discriminatedUnion('type', [ export const ConfiguredLlmJudgeActionFriendlySchema = z.object({ model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), prompt: z.string(), - response_format: ResponseFormatSchema, + response_schema: ResponseFormatSchema, }); type ConfiguredLlmJudgeActionFriendlyType = z.infer< typeof ConfiguredLlmJudgeActionFriendlySchema @@ -34,7 +34,7 @@ export const actionTemplates: Array<{ type: { model: 'gpt-4o-mini', prompt: 'Is the output relevant to the input?', - response_format: { + response_schema: { type: 'simple', schema: 'boolean', }, @@ -46,7 +46,7 @@ export const actionTemplates: Array<{ model: 'gpt-4o-mini', prompt: 'Given the input and output, and your knowledge of the world, is the output correct?', - response_format: { + response_schema: { type: 'structured', schema: { is_correct: 'boolean', From 2e832353c3292117e730797de4c87887b364eb16 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 22:35:56 -0800 Subject: [PATCH 058/120] Fix action executor --- .../Home/Browse3/pages/CallPage/CallActionsViewer.tsx | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index 18feb056667..47edd7df4eb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -36,8 +36,9 @@ const RunButton: React.FC<{ entity: string; project: string; refetchFeedback: () => void; - getClient: () => any; -}> = ({actionRef, callId, entity, project, refetchFeedback, getClient}) => { +}> = ({actionRef, callId, entity, project, refetchFeedback}) => { + const getClient = useGetTraceServerClientContext(); + const [isRunning, setIsRunning] = useState(false); const [error, setError] = useState(null); @@ -48,7 +49,7 @@ const RunButton: React.FC<{ await getClient().actionsExecuteBatch({ project_id: projectIdFromParts({entity, project}), call_ids: [callId], - configured_action_ref: actionRef, + action_ref: actionRef, }); refetchFeedback(); } catch (err) { @@ -129,7 +130,6 @@ export const CallActionsViewer: React.FC<{ [verifiedActionFeedbacks] ); - const getClient = useGetTraceServerClientContext(); const allCallActions: CallActionRow[] = useMemo(() => { return ( @@ -206,7 +206,6 @@ export const CallActionsViewer: React.FC<{ entity={props.call.entity} project={props.call.project} refetchFeedback={feedbackQuery.refetch} - getClient={getClient} /> ), }, From 546e96adc76b324b4db1d9a3927e4169e81c1e9e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 5 Nov 2024 22:44:52 -0800 Subject: [PATCH 059/120] Call Action Viewer complete --- .../pages/CallPage/CallActionsViewer.tsx | 49 +++---------------- .../traceServerClientTypes.ts | 1 + 2 files changed, 8 insertions(+), 42 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index 47edd7df4eb..3be291fb12e 100644 --- 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -1,9 +1,7 @@ import {Button} from '@wandb/weave/components/Button/Button'; import {Timestamp} from '@wandb/weave/components/Timestamp'; -import {parseRef} from '@wandb/weave/react'; import {makeRefCall} from '@wandb/weave/util/refs'; import React, {useCallback, useMemo, useState} from 'react'; -import {z} from 'zod'; import {CellValue} from '../../../Browse2/CellValue'; import {NotApplicable} from '../../../Browse2/NotApplicable'; @@ -101,35 +99,25 @@ export const CallActionsViewer: React.FC<{ }).result ?? [] ).sort((a, b) => (a.val.name ?? '').localeCompare(b.val.name ?? '')); const verifiedActionFeedbacks: Array<{ - data: MachineScoreFeedbackPayloadType; + data: any; feedbackRaw: Feedback; }> = useMemo(() => { return (feedbackQuery.result ?? []) + .filter(f => f.feedback_type?.startsWith('wandb.runnable')) .map(feedback => { - const res = MachineScoreFeedbackPayloadSchema.safeParse( - feedback.payload - ); - return {res, feedbackRaw: feedback}; - }) - .filter(result => result.res.success) - .map(result => ({ - data: result.res.data, - feedbackRaw: result.feedbackRaw, - })) as Array<{ - data: MachineScoreFeedbackPayloadType; - feedbackRaw: Feedback; - }>; + return {data: feedback.payload.output, feedbackRaw: feedback}; + }); }, [feedbackQuery.result]); const getFeedbackForAction = useCallback( (actionRef: string) => { return verifiedActionFeedbacks.filter( - feedback => feedback.data.runnable_ref === actionRef + feedback => feedback.feedbackRaw.runnable_ref === actionRef ); }, [verifiedActionFeedbacks] ); - + console.log(verifiedActionFeedbacks); const allCallActions: CallActionRow[] = useMemo(() => { return ( @@ -153,9 +141,7 @@ export const CallActionsViewer: React.FC<{ lastRanAt: selectedFeedback ? 
convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') : undefined, - lastResult: selectedFeedback - ? getValueFromMachineScoreFeedbackPayload(selectedFeedback.data) - : undefined, + lastResult: selectedFeedback ? selectedFeedback.data : undefined, }; }) ?? [] ); @@ -256,24 +242,3 @@ export const CallActionsViewer: React.FC<{ ); }; - -const MachineScoreFeedbackPayloadSchema = z.object({ - // _type: z.literal("ActionFeedback"), - runnable_ref: z.string(), - call_ref: z.string().optional(), - trigger_ref: z.string().optional(), - value: z.record(z.string(), z.record(z.string(), z.boolean())), -}); - -type MachineScoreFeedbackPayloadType = z.infer< - typeof MachineScoreFeedbackPayloadSchema ->; - -const getValueFromMachineScoreFeedbackPayload = ( - payload: MachineScoreFeedbackPayloadType -) => { - const ref = parseRef(payload.runnable_ref); - const name = ref.artifactName; - const digest = ref.artifactVersion; - return payload.value[name][digest]; -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index bcc68300a4f..6e5fd393ef7 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -171,6 +171,7 @@ export type Feedback = { created_at: string; feedback_type: string; payload: Record; + runnable_ref?: string; }; export type FeedbackQuerySuccess = { From 3b852864bfea07ac81f0ecfa9d2ca16193cc25fe Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 13:52:12 -0800 Subject: [PATCH 060/120] Call Action Viewer complete --- .../Browse3/collections/actionCollection.ts | 51 ------------------- .../NewBuiltInActionScorerModal.tsx | 22 ++------ .../wfReactInterface/baseObjectClasses.zod.ts | 25 +++++++++ 
.../generatedBaseObjectClasses.zod.ts | 34 ++++--------- 4 files changed, 40 insertions(+), 92 deletions(-) delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClasses.zod.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts deleted file mode 100644 index e74a2df6f17..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/collections/actionCollection.ts +++ /dev/null @@ -1,51 +0,0 @@ -import {z} from 'zod'; - -const JSONTypeNames = z.enum(['boolean', 'number', 'string']); -const SimpleJsonResponseFormat = z.object({type: JSONTypeNames}); -const ObjectJsonResponseFormat = z.object({ - type: z.literal('object'), - properties: z.record(SimpleJsonResponseFormat), - additionalProperties: z.literal(false), -}); - -export const ConfiguredLlmJudgeActionSchema = z.object({ - action_type: z.literal('llm_judge'), - model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), - prompt: z.string(), - response_schema: z.discriminatedUnion('type', [ - SimpleJsonResponseFormat, - ObjectJsonResponseFormat, - ]), -}); -export type ConfiguredLlmJudgeActionType = z.infer< - typeof ConfiguredLlmJudgeActionSchema ->; - -// export const ConfiguredWordCountActionSchema = z.object({ -// action_type: z.literal('wordcount'), -// }); -// export type ConfiguredWordCountActionType = z.infer< -// typeof ConfiguredWordCountActionSchema -// >; - -// export const ActionConfigSchema = z.discriminatedUnion('action_type', [ -// ConfiguredLlmJudgeActionSchema, -// ConfiguredWordCountActionSchema, -// ]); -// export type ActionConfigType = z.infer; - -// export const ActionDefinitionSchema = z.object({ -// name: z.string(), -// config: ActionConfigSchema, -// }); -// export 
type ActionDefinitionType = z.infer; - -// export const ActionDispatchFilterSchema = z.object({ -// op_name: z.string(), -// sample_rate: z.number().min(0).max(1).default(1), -// configured_action_ref: z.string(), -// disabled: z.boolean().optional(), -// }); -// export type ActionDispatchFilterType = z.infer< -// typeof ActionDispatchFilterSchema -// >; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index 600d564b3ec..d32066ee387 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -9,9 +9,9 @@ import _ from 'lodash'; import React, {FC, useEffect, useState} from 'react'; import {z} from 'zod'; -import {ConfiguredLlmJudgeActionSchema} from '../../collections/actionCollection'; import {DynamicConfigForm} from '../../DynamicConfigForm'; import {ReusableDrawer} from '../../ReusableDrawer'; +import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod'; import { ActionDefinition, ActionDefinitionSchema, @@ -24,14 +24,14 @@ import { const knownBuiltinActions = [ { name: 'LLM Judge', - actionSchema: ConfiguredLlmJudgeActionSchema, + actionSchema: LlmJudgeActionSpecSchema, friendly: { schema: ConfiguredLlmJudgeActionFriendlySchema, convert: ( data: z.infer - ): z.infer => { + ): z.infer => { let responseFormat: z.infer< - typeof ConfiguredLlmJudgeActionSchema + typeof LlmJudgeActionSpecSchema >['response_schema']; if (data.response_schema.type === 'simple') { responseFormat = {type: data.response_schema.schema}; @@ -53,20 +53,6 @@ const knownBuiltinActions = [ }, }, }, - // { - // name: 'Word Count', - // actionSchema: ConfiguredWordCountActionSchema, - // friendly: { - // schema: 
z.object({}), - // convert: ( - // data: z.infer - // ): z.infer => { - // return { - // action_type: 'wordcount', - // }; - // }, - // }, - // }, ]; interface NewBuiltInActionScorerModalProps { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClasses.zod.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClasses.zod.ts new file mode 100644 index 00000000000..e7f431f7150 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClasses.zod.ts @@ -0,0 +1,25 @@ +/** + * This file contains zod schemas for our baseObjectClasses that are not + * correctly / completely generated by the json-schema to zod converter. + */ + +import * as z from 'zod'; + +const JSONTypeNames = z.enum(['boolean', 'number', 'string']); +const SimpleJsonResponseFormat = z.object({type: JSONTypeNames}); +const ObjectJsonResponseFormat = z.object({ + type: z.literal('object'), + properties: z.record(SimpleJsonResponseFormat), + additionalProperties: z.literal(false), +}); + +export const LlmJudgeActionSpecSchema = z.object({ + action_type: z.literal('llm_judge'), + model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), + prompt: z.string(), + response_schema: z.discriminatedUnion('type', [ + SimpleJsonResponseFormat, + ObjectJsonResponseFormat, + ]), +}); +export type LlmJudgeActionSpec = z.infer; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts index 10ba7842599..cee00d88c17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts @@ -1,31 +1,19 @@ import * as z 
from 'zod'; -// BEGINNING OF CUSTOM CODE ///// -// Sadly, the json-schema to zod converter doesn't support discriminator -// so we have to define the schemas manually. If you run the generator -// make sure to review the changes to this section. -export const LlmJudgeActionSpecSchema = z.object({ - action_type: z.enum(['llm_judge']), - model: z.enum(['gpt-4o', 'gpt-4o-mini']), - prompt: z.string(), - response_schema: z.record(z.string(), z.any()), -}); -export type LlmJudgeActionSpec = z.infer; +export const ActionTypeSchema = z.enum(['contains_words', 'llm_judge']); +export type ActionType = z.infer; -export const ContainsWordsActionSpecSchema = z.object({ - action_type: z.enum(['contains_words']), - target_words: z.array(z.string()), -}); -export type ContainsWordsActionSpec = z.infer< - typeof ContainsWordsActionSpecSchema ->; +export const ModelSchema = z.enum(['gpt-4o', 'gpt-4o-mini']); +export type Model = z.infer; -export const SpecSchema = z.discriminatedUnion('action_type', [ - LlmJudgeActionSpecSchema, - ContainsWordsActionSpecSchema, -]); +export const SpecSchema = z.object({ + action_type: ActionTypeSchema.optional(), + model: ModelSchema.optional(), + prompt: z.string().optional(), + response_schema: z.record(z.string(), z.any()).optional(), + target_words: z.array(z.string()).optional(), +}); export type Spec = z.infer; -// END OF CUSTOM CODE ///// export const LeaderboardColumnSchema = z.object({ evaluation_object_ref: z.string(), From fb91c210bb1bee00a3f99fbdba0f13fa964b303f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 14:06:23 -0800 Subject: [PATCH 061/120] Call Action Viewer complete --- .../NewBuiltInActionScorerModal.tsx | 48 ++----------------- ...actionTemplates.tsx => actionTemplates.ts} | 0 .../pages/ScorersPage/builtinActions.ts | 46 ++++++++++++++++++ 3 files changed, 50 insertions(+), 44 deletions(-) rename weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/{actionTemplates.tsx => 
actionTemplates.ts} (100%) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/builtinActions.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index d32066ee387..0bedb91ff0f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -7,53 +7,15 @@ import { } from '@material-ui/core'; import _ from 'lodash'; import React, {FC, useEffect, useState} from 'react'; -import {z} from 'zod'; import {DynamicConfigForm} from '../../DynamicConfigForm'; import {ReusableDrawer} from '../../ReusableDrawer'; -import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod'; import { ActionDefinition, ActionDefinitionSchema, } from '../wfReactInterface/generatedBaseObjectClasses.zod'; -import { - actionTemplates, - ConfiguredLlmJudgeActionFriendlySchema, -} from './actionTemplates'; - -const knownBuiltinActions = [ - { - name: 'LLM Judge', - actionSchema: LlmJudgeActionSpecSchema, - friendly: { - schema: ConfiguredLlmJudgeActionFriendlySchema, - convert: ( - data: z.infer - ): z.infer => { - let responseFormat: z.infer< - typeof LlmJudgeActionSpecSchema - >['response_schema']; - if (data.response_schema.type === 'simple') { - responseFormat = {type: data.response_schema.schema}; - } else { - responseFormat = { - type: 'object', - properties: _.mapValues(data.response_schema.schema, value => ({ - type: value as 'boolean' | 'number' | 'string', - })), - additionalProperties: false, - }; - } - return { - action_type: 'llm_judge', - model: data.model, - prompt: data.prompt, - response_schema: responseFormat, - }; - }, - }, - }, -]; +import {actionTemplates} from 
'./actionTemplates'; +import {knownBuiltinActions} from './builtinActions'; interface NewBuiltInActionScorerModalProps { open: boolean; @@ -85,9 +47,7 @@ export const NewBuiltInActionScorerModal: FC< const handleSave = () => { const newAction = ActionDefinitionSchema.parse({ name, - spec: knownBuiltinActions[selectedActionIndex].friendly.convert( - config as any - ), + spec: knownBuiltinActions[selectedActionIndex].convert(config as any), }); onSave(newAction); setConfig({}); @@ -128,7 +88,7 @@ export const NewBuiltInActionScorerModal: FC< {selectedActionIndex !== -1 && ( = { + name: string; + actionSchema: A; + inputFriendlySchema: F; + convert: (data: z.infer) => z.infer; +}; + +export const knownBuiltinActions: KnownBuiltingAction[] = [ + { + name: 'LLM Judge', + actionSchema: LlmJudgeActionSpecSchema, + inputFriendlySchema: ConfiguredLlmJudgeActionFriendlySchema, + convert: ( + data: z.infer + ): z.infer => { + let responseFormat: z.infer< + typeof LlmJudgeActionSpecSchema + >['response_schema']; + if (data.response_schema.type === 'simple') { + responseFormat = {type: data.response_schema.schema}; + } else { + responseFormat = { + type: 'object', + properties: _.mapValues(data.response_schema.schema, value => ({ + type: value as 'boolean' | 'number' | 'string', + })), + additionalProperties: false, + }; + } + return { + action_type: 'llm_judge', + model: data.model, + prompt: data.prompt, + response_schema: responseFormat, + }; + }, + }, +]; From 68352df1cde850e0d77587e946a849334c38efcb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 14:49:16 -0800 Subject: [PATCH 062/120] Fixed up names --- .../NewBuiltInActionScorerModal.tsx | 52 ++++---- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 45 ++++--- .../actionDefinitionConfigurationSpecs.ts | 118 ++++++++++++++++++ .../pages/ScorersPage/actionTemplates.ts | 58 --------- .../pages/ScorersPage/builtinActions.ts | 46 ------- 5 files changed, 178 insertions(+), 141 deletions(-) create mode 100644 
weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionDefinitionConfigurationSpecs.ts delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.ts delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/builtinActions.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx index 0bedb91ff0f..7598f1c1ff3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx @@ -5,7 +5,6 @@ import { Select, TextField, } from '@material-ui/core'; -import _ from 'lodash'; import React, {FC, useEffect, useState} from 'react'; import {DynamicConfigForm} from '../../DynamicConfigForm'; @@ -13,31 +12,35 @@ import {ReusableDrawer} from '../../ReusableDrawer'; import { ActionDefinition, ActionDefinitionSchema, + ActionType, } from '../wfReactInterface/generatedBaseObjectClasses.zod'; -import {actionTemplates} from './actionTemplates'; -import {knownBuiltinActions} from './builtinActions'; +import {actionDefinitionConfigurationSpecs} from './actionDefinitionConfigurationSpecs'; interface NewBuiltInActionScorerModalProps { open: boolean; onClose: () => void; onSave: (newAction: ActionDefinition) => void; - initialTemplate: string; + initialTemplate?: { + actionType: ActionType; + template: {name: string; config: Record}; + } | null; } export const NewBuiltInActionScorerModal: FC< NewBuiltInActionScorerModalProps > = ({open, onClose, onSave, initialTemplate}) => { const [name, setName] = useState(''); - const [selectedActionIndex, setSelectedActionIndex] = useState(0); + const [selectedActionType, setSelectedActionType] = + 
useState('llm_judge'); const [config, setConfig] = useState>({}); + const selectedActionDefinitionConfigurationSpec = + actionDefinitionConfigurationSpecs[selectedActionType]; useEffect(() => { if (initialTemplate) { - const template = actionTemplates.find(t => t.name === initialTemplate); - if (template) { - setConfig(template.type); - setName(template.name); - } + setConfig(initialTemplate.template.config); + setSelectedActionType(initialTemplate.actionType); + setName(initialTemplate.template.name); } else { setConfig({}); setName(''); @@ -45,13 +48,16 @@ export const NewBuiltInActionScorerModal: FC< }, [initialTemplate]); const handleSave = () => { + if (!selectedActionDefinitionConfigurationSpec) { + return; + } const newAction = ActionDefinitionSchema.parse({ name, - spec: knownBuiltinActions[selectedActionIndex].convert(config as any), + spec: selectedActionDefinitionConfigurationSpec.convert(config as any), }); onSave(newAction); setConfig({}); - setSelectedActionIndex(0); + setSelectedActionType('llm_judge'); setName(''); }; @@ -74,21 +80,21 @@ export const NewBuiltInActionScorerModal: FC< Action Type - {selectedActionIndex !== -1 && ( + {selectedActionDefinitionConfigurationSpec && ( = ({entity, project}) => { const [isModalOpen, setIsModalOpen] = useState(false); - const [selectedTemplate, setSelectedTemplate] = useState(''); + const [selectedTemplate, setSelectedTemplate] = useState<{ + actionType: ActionType; + template: {name: string; config: Record}; + } | null>(null); const createCollectionObject = useCreateBaseObjectInstance('ActionDefinition'); const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); @@ -70,7 +76,7 @@ const OnlineScorersTab: React.FC<{ const open = Boolean(anchorEl); const handleCreateBlank = () => { - setSelectedTemplate(''); + setSelectedTemplate(null); setIsModalOpen(true); }; @@ -82,15 +88,18 @@ const OnlineScorersTab: React.FC<{ setAnchorEl(null); }; - const handleTemplateSelect = (templateName: string) => { - 
setSelectedTemplate(templateName); + const handleTemplateSelect = (template: { + actionType: ActionType; + template: {name: string; config: Record}; + }) => { + setSelectedTemplate(template); setIsModalOpen(true); handleClose(); }; const handleCloseModal = () => { setIsModalOpen(false); - setSelectedTemplate(''); + setSelectedTemplate(null); }; const handleSaveModal = (newAction: ActionDefinition) => { @@ -150,13 +159,21 @@ const OnlineScorersTab: React.FC<{ /> - {actionTemplates.map(template => ( - handleTemplateSelect(template.name)}> - {template.name} - - ))} + {Object.entries(actionDefinitionConfigurationSpecs).flatMap( + ([actionType, spec]) => + spec.templates.map(template => ( + + handleTemplateSelect({ + actionType: actionType as ActionType, + template, + }) + }> + {template.name} + + )) + )} = { + name: string; + actionSchema: A; + inputFriendlySchema: F; + convert: (data: z.infer) => z.infer; + templates: Array<{ + name: string; + config: z.infer; + }>; +}; + +export const actionDefinitionConfigurationSpecs: Partial< + Record +> = { + llm_judge: { + name: 'LLM Judge', + actionSchema: LlmJudgeActionSpecSchema, + inputFriendlySchema: ConfiguredLlmJudgeActionFriendlySchema, + convert: ( + data: z.infer + ): z.infer => { + let responseFormat: z.infer< + typeof LlmJudgeActionSpecSchema + >['response_schema']; + if (data.response_schema.type === 'simple') { + responseFormat = {type: data.response_schema.schema}; + } else { + responseFormat = { + type: 'object', + properties: _.mapValues(data.response_schema.schema, value => ({ + type: value as 'boolean' | 'number' | 'string', + })), + additionalProperties: false, + }; + } + return { + action_type: 'llm_judge', + model: data.model, + prompt: data.prompt, + response_schema: responseFormat, + }; + }, + templates: [ + { + name: 'RelevancyJudge', + config: { + model: 'gpt-4o-mini', + prompt: 'Is the output relevant to the input?', + response_schema: { + type: 'simple', + schema: 'boolean', + }, + }, + }, + { + name: 
'CorrectnessJudge', + config: { + model: 'gpt-4o-mini', + prompt: + 'Given the input and output, and your knowledge of the world, is the output correct?', + response_schema: { + type: 'structured', + schema: { + is_correct: 'boolean', + reason: 'string', + }, + }, + }, + }, + ], + }, +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.ts deleted file mode 100644 index f0f0d515b0d..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionTemplates.ts +++ /dev/null @@ -1,58 +0,0 @@ -import {z} from 'zod'; - -const SimpleResponseFormatSchema = z - .enum(['boolean', 'number', 'string']) - .default('boolean'); -const StructuredResponseFormatSchema = z.record(SimpleResponseFormatSchema); - -const ResponseFormatSchema = z.discriminatedUnion('type', [ - z.object({ - type: z.literal('simple'), - schema: SimpleResponseFormatSchema, - }), - z.object({ - type: z.literal('structured'), - schema: StructuredResponseFormatSchema, - }), -]); - -export const ConfiguredLlmJudgeActionFriendlySchema = z.object({ - model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), - prompt: z.string(), - response_schema: ResponseFormatSchema, -}); -type ConfiguredLlmJudgeActionFriendlyType = z.infer< - typeof ConfiguredLlmJudgeActionFriendlySchema ->; - -export const actionTemplates: Array<{ - name: string; - type: ConfiguredLlmJudgeActionFriendlyType; -}> = [ - { - name: 'RelevancyJudge', - type: { - model: 'gpt-4o-mini', - prompt: 'Is the output relevant to the input?', - response_schema: { - type: 'simple', - schema: 'boolean', - }, - }, - }, - { - name: 'CorrectnessJudge', - type: { - model: 'gpt-4o-mini', - prompt: - 'Given the input and output, and your knowledge of the world, is the output correct?', - response_schema: { - type: 'structured', - schema: { - is_correct: 'boolean', - reason: 
'string', - }, - }, - }, - }, -]; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/builtinActions.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/builtinActions.ts deleted file mode 100644 index 740e05b0a90..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/builtinActions.ts +++ /dev/null @@ -1,46 +0,0 @@ -import {z} from 'zod'; - -import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod'; -import {ConfiguredLlmJudgeActionFriendlySchema} from './actionTemplates'; - -type KnownBuiltingAction< - A extends z.ZodTypeAny = z.ZodTypeAny, - F extends z.ZodTypeAny = z.ZodTypeAny -> = { - name: string; - actionSchema: A; - inputFriendlySchema: F; - convert: (data: z.infer) => z.infer; -}; - -export const knownBuiltinActions: KnownBuiltingAction[] = [ - { - name: 'LLM Judge', - actionSchema: LlmJudgeActionSpecSchema, - inputFriendlySchema: ConfiguredLlmJudgeActionFriendlySchema, - convert: ( - data: z.infer - ): z.infer => { - let responseFormat: z.infer< - typeof LlmJudgeActionSpecSchema - >['response_schema']; - if (data.response_schema.type === 'simple') { - responseFormat = {type: data.response_schema.schema}; - } else { - responseFormat = { - type: 'object', - properties: _.mapValues(data.response_schema.schema, value => ({ - type: value as 'boolean' | 'number' | 'string', - })), - additionalProperties: false, - }; - } - return { - action_type: 'llm_judge', - model: data.model, - prompt: data.prompt, - response_schema: responseFormat, - }; - }, - }, -]; From 901a788bab9afa2c616742cf8e10eeff5728c92d Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 14:53:35 -0800 Subject: [PATCH 063/120] Fixed --- ...onScorerModal.tsx => NewActionDefinitionModal.tsx} | 11 +++++++---- .../Home/Browse3/pages/ScorersPage/ScorersPage.tsx | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) rename 
weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/{NewBuiltInActionScorerModal.tsx => NewActionDefinitionModal.tsx} (94%) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx similarity index 94% rename from weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx rename to weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx index 7598f1c1ff3..e993312aa3d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewBuiltInActionScorerModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx @@ -16,7 +16,7 @@ import { } from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {actionDefinitionConfigurationSpecs} from './actionDefinitionConfigurationSpecs'; -interface NewBuiltInActionScorerModalProps { +interface NewActionDefinitionModalProps { open: boolean; onClose: () => void; onSave: (newAction: ActionDefinition) => void; @@ -26,9 +26,12 @@ interface NewBuiltInActionScorerModalProps { } | null; } -export const NewBuiltInActionScorerModal: FC< - NewBuiltInActionScorerModalProps -> = ({open, onClose, onSave, initialTemplate}) => { +export const NewActionDefinitionModal: FC = ({ + open, + onClose, + onSave, + initialTemplate, +}) => { const [name, setName] = useState(''); const [selectedActionType, setSelectedActionType] = useState('llm_judge'); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 7776cd8fe63..de666af7bdf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx 
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -13,7 +13,7 @@ import { } from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {actionDefinitionConfigurationSpecs} from './actionDefinitionConfigurationSpecs'; -import {NewBuiltInActionScorerModal} from './NewBuiltInActionScorerModal'; +import {NewActionDefinitionModal} from './NewActionDefinitionModal'; export const ScorersPage: React.FC<{ entity: string; @@ -184,7 +184,7 @@ const OnlineScorersTab: React.FC<{ baseObjectClass: 'ActionDefinition', }} /> - Date: Wed, 6 Nov 2024 15:03:39 -0800 Subject: [PATCH 064/120] Fixed --- .../Home/Browse3/pages/ObjectVersionsPage.tsx | 3 +++ .../Browse3/pages/ScorersPage/ScorersPage.tsx | 2 +- .../Home/Browse3/pages/common/EmptyContent.tsx | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx index eed18b9a70c..2cb8595655d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx @@ -28,6 +28,7 @@ import {StyledDataGrid} from '../StyledDataGrid'; import {basicField} from './common/DataTable'; import {Empty} from './common/Empty'; import { + EMPTY_PROPS_PROGRAMATIC_SCORERS, EMPTY_PROPS_DATASETS, EMPTY_PROPS_LEADERBOARDS, EMPTY_PROPS_MODEL, @@ -167,6 +168,8 @@ export const FilterableObjectVersionsTable: React.FC<{ propsEmpty = EMPTY_PROPS_DATASETS; } else if (base === 'Leaderboard') { propsEmpty = EMPTY_PROPS_LEADERBOARDS; + } else if (base === 'Scorer') { + propsEmpty = EMPTY_PROPS_PROGRAMATIC_SCORERS; } return ; } diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index de666af7bdf..6d0c123b220 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -35,7 +35,7 @@ export const ScorersPage: React.FC<{ // content: , // }, { - label: 'Code Scorers', + label: 'Programmatic Scorers', content: , }, ]} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx index f2db0da100b..936f88382df 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx @@ -187,3 +187,19 @@ export const EMPTY_NO_TRACE_SERVER: EmptyProps = { ), }; + +export const EMPTY_PROPS_PROGRAMATIC_SCORERS: EmptyProps = { + icon: 'type-number-alt' as const, + heading: 'No programmatic scorers yet', + description: + 'Create programmatic scorers in Python.', + moreInformation: ( + <> + Learn more about{' '} + + creating and using scorers + {' '} + in evaluations. 
+ + ), +}; \ No newline at end of file From 16a96a1eeb557947350982d1241c98493342ef81 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 15:12:38 -0800 Subject: [PATCH 065/120] Added empty states --- .../Home/Browse3/pages/ObjectVersionsPage.tsx | 5 ++++- .../Home/Browse3/pages/ScorersPage/ScorersPage.tsx | 6 +++--- .../Home/Browse3/pages/common/EmptyContent.tsx | 14 +++++++++++++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx index 2cb8595655d..c6b6353caf1 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx @@ -28,11 +28,12 @@ import {StyledDataGrid} from '../StyledDataGrid'; import {basicField} from './common/DataTable'; import {Empty} from './common/Empty'; import { - EMPTY_PROPS_PROGRAMATIC_SCORERS, + EMPTY_PROPS_ACTION_DEFINITIONS, EMPTY_PROPS_DATASETS, EMPTY_PROPS_LEADERBOARDS, EMPTY_PROPS_MODEL, EMPTY_PROPS_OBJECTS, + EMPTY_PROPS_PROGRAMATIC_SCORERS, EMPTY_PROPS_PROMPTS, } from './common/EmptyContent'; import { @@ -170,6 +171,8 @@ export const FilterableObjectVersionsTable: React.FC<{ propsEmpty = EMPTY_PROPS_LEADERBOARDS; } else if (base === 'Scorer') { propsEmpty = EMPTY_PROPS_PROGRAMATIC_SCORERS; + } else if (base === 'ActionDefinition') { + propsEmpty = EMPTY_PROPS_ACTION_DEFINITIONS; } return ; } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 6d0c123b220..319b42efaab 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx 
@@ -26,8 +26,8 @@ export const ScorersPage: React.FC<{ { // It is true that this panel can show more than LLM Judges, but the // branding is better - label: 'Configurable LLM Judges', - content: , + label: 'Configurable Judges', + content: , }, // This is a placeholder for Griffin's annotation column manager section // { @@ -59,7 +59,7 @@ const CodeScorersTab: React.FC<{ ); }; -const OnlineScorersTab: React.FC<{ +const ActionDefinitionsTab: React.FC<{ entity: string; project: string; }> = ({entity, project}) => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx index 936f88382df..713d26a7f93 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx @@ -202,4 +202,16 @@ export const EMPTY_PROPS_PROGRAMATIC_SCORERS: EmptyProps = { in evaluations. 
), -}; \ No newline at end of file +}; + +export const EMPTY_PROPS_ACTION_DEFINITIONS: EmptyProps = { + icon: 'automation-robot-arm' as const, + heading: 'No configurations yet', + description: + 'Create new configuration by clicking "Create new" in the top right.', + moreInformation: ( + <> + + + ), +}; From dfb120e0cb82916d02da843605f1c2390b5ea052 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 15:27:10 -0800 Subject: [PATCH 066/120] Fix typing --- .../Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx | 4 ++-- .../pages/ScorersPage/actionDefinitionConfigurationSpecs.ts | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx index e993312aa3d..8225abe9907 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx @@ -32,7 +32,7 @@ export const NewActionDefinitionModal: FC = ({ onSave, initialTemplate, }) => { - const [name, setName] = useState(''); + const [name, setName] = useState(''); const [selectedActionType, setSelectedActionType] = useState('llm_judge'); const [config, setConfig] = useState>({}); @@ -72,7 +72,7 @@ export const NewActionDefinitionModal: FC = ({ title="Configure Scorer" onClose={onClose} onSave={handleSave} - saveDisabled={!isValid}> + saveDisabled={!isValid || name === ''}> Date: Wed, 6 Nov 2024 15:28:43 -0800 Subject: [PATCH 067/120] lint --- .../Home/Browse3/pages/common/EmptyContent.tsx | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx index 713d26a7f93..41cb0eee54e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx @@ -191,13 +191,12 @@ export const EMPTY_NO_TRACE_SERVER: EmptyProps = { export const EMPTY_PROPS_PROGRAMATIC_SCORERS: EmptyProps = { icon: 'type-number-alt' as const, heading: 'No programmatic scorers yet', - description: - 'Create programmatic scorers in Python.', + description: 'Create programmatic scorers in Python.', moreInformation: ( <> Learn more about{' '} - creating and using scorers + creating and using scorers {' '} in evaluations. @@ -209,9 +208,5 @@ export const EMPTY_PROPS_ACTION_DEFINITIONS: EmptyProps = { heading: 'No configurations yet', description: 'Create new configuration by clicking "Create new" in the top right.', - moreInformation: ( - <> - - - ), + moreInformation: <>, }; From 73f3c78f566a3f2ba53feff0561f1f2addeac512 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 16:23:41 -0800 Subject: [PATCH 068/120] Added more --- .../Home/Browse3/pages/CallPage/CallPage.tsx | 4 +- ...ActionsViewer.tsx => CallScoresViewer.tsx} | 181 ++++++++++-------- 2 files changed, 99 insertions(+), 86 deletions(-) rename weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/{CallActionsViewer.tsx => CallScoresViewer.tsx} (61%) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 4675d5b3e27..c849637a32b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -28,10 +28,10 @@ import {TabUseCall} from '../TabUseCall'; import {useURLSearchParamsDict} from 
'../util'; import {useWFHooks} from '../wfReactInterface/context'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; -import {CallActionsViewer} from './CallActionsViewer'; import {CallChat} from './CallChat'; import {CallDetails} from './CallDetails'; import {CallOverview} from './CallOverview'; +import {CallScoresViewer} from './CallScoresViewer'; import {CallSummary} from './CallSummary'; import {CallTraceView, useCallFlattenedTraceTree} from './CallTraceView'; @@ -133,7 +133,7 @@ const useCallTabs = (call: CallSchema) => { label: 'Scores', content: ( - + ), }, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx similarity index 61% rename from weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx rename to weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 3be291fb12e..9cd71d08ade 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -1,32 +1,25 @@ +import {Box} from '@material-ui/core'; +import {GridColDef} from '@mui/x-data-grid-pro'; import {Button} from '@wandb/weave/components/Button/Button'; import {Timestamp} from '@wandb/weave/components/Timestamp'; +import {parseRef} from '@wandb/weave/react'; import {makeRefCall} from '@wandb/weave/util/refs'; -import React, {useCallback, useMemo, useState} from 'react'; +import _ from 'lodash'; +import React, {useMemo, useState} from 'react'; import {CellValue} from '../../../Browse2/CellValue'; import {NotApplicable} from '../../../Browse2/NotApplicable'; -// import {ActionDefinitionType} from '../../collections/actionCollection'; +import {SmallRef} from '../../../Browse2/SmallRef'; import {StyledDataGrid} 
from '../../StyledDataGrid'; // Import the StyledDataGrid component import {useBaseObjectInstances} from '../wfReactInterface/baseObjectClassQuery'; import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; import {useWFHooks} from '../wfReactInterface/context'; -import {ActionDefinition} from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext'; import {Feedback} from '../wfReactInterface/traceServerClientTypes'; -import { - convertISOToDate, - projectIdFromParts, -} from '../wfReactInterface/tsDataModelHooks'; +import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; -type CallActionRow = { - actionRef: string; - actionDef: ActionDefinition; - runCount: number; - lastResult?: unknown; - lastRanAt?: Date; -}; // New RunButton component const RunButton: React.FC<{ actionRef: string; @@ -74,7 +67,7 @@ const RunButton: React.FC<{ ); }; -export const CallActionsViewer: React.FC<{ +export const CallScoresViewer: React.FC<{ call: CallSchema; }> = props => { const {useFeedback} = useWFHooks(); @@ -98,31 +91,11 @@ export const CallActionsViewer: React.FC<{ filter: {latest_only: true}, }).result ?? [] ).sort((a, b) => (a.val.name ?? '').localeCompare(b.val.name ?? '')); - const verifiedActionFeedbacks: Array<{ - data: any; - feedbackRaw: Feedback; - }> = useMemo(() => { - return (feedbackQuery.result ?? 
[]) - .filter(f => f.feedback_type?.startsWith('wandb.runnable')) - .map(feedback => { - return {data: feedback.payload.output, feedbackRaw: feedback}; - }); - }, [feedbackQuery.result]); - - const getFeedbackForAction = useCallback( - (actionRef: string) => { - return verifiedActionFeedbacks.filter( - feedback => feedback.feedbackRaw.runnable_ref === actionRef - ); - }, - [verifiedActionFeedbacks] - ); - console.log(verifiedActionFeedbacks); - const allCallActions: CallActionRow[] = useMemo(() => { - return ( + const actionRunnableRefs = useMemo(() => { + return new Set( actionDefinitions.map(actionDefinition => { - const actionDefinitionRefUri = objectVersionKeyToRefUri({ + return objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', entity: props.call.entity, @@ -131,50 +104,97 @@ export const CallActionsViewer: React.FC<{ versionHash: actionDefinition.digest, path: '', }); - const feedbacks = getFeedbackForAction(actionDefinitionRefUri); - const selectedFeedback = - feedbacks.length > 0 ? feedbacks[0] : undefined; - return { - actionRef: actionDefinitionRefUri, - actionDef: actionDefinition.val, - runCount: feedbacks.length, - lastRanAt: selectedFeedback - ? convertISOToDate(selectedFeedback.feedbackRaw.created_at + 'Z') - : undefined, - lastResult: selectedFeedback ? selectedFeedback.data : undefined, - }; - }) ?? [] + }) ); - }, [ - actionDefinitions, - getFeedbackForAction, - props.call.entity, - props.call.project, - ]); + }, [actionDefinitions, props.call.entity, props.call.project]); + + const runnableFeedbacks: Feedback[] = useMemo(() => { + return (feedbackQuery.result ?? 
[]).filter( + f => + f.feedback_type?.startsWith('wandb.runnable') && f.runnable_ref !== null + ); + }, [feedbackQuery.result]); - const columns = [ - {field: 'action', headerName: 'Action', flex: 1}, - {field: 'runCount', headerName: 'Run Count', flex: 1}, + const rows = useMemo(() => { + return _.sortBy( + Object.entries(_.groupBy(runnableFeedbacks, f => f.feedback_type)).map( + ([runnableRef, fs]) => { + const val = _.reverse(_.sortBy(fs, 'created_at'))[0]; + return { + id: val.feedback_type, + feedback: val, + runCount: fs.length, + }; + } + ), + s => s.feedback.feedback_type + ); + }, [runnableFeedbacks]); + + const columns: Array> = [ + { + field: 'scorer', + headerName: 'Scorer', + width: 100, + renderCell: params => { + return params.row.feedback.feedback_type.split('.').pop(); + }, + }, + { + field: 'runnable_ref', + headerName: 'Logic', + width: 60, + renderCell: params => { + return ( + + + + ); + }, + }, + {field: 'runCount', headerName: 'Runs', width: 55}, { field: 'lastResult', headerName: 'Last Result', flex: 1, renderCell: (params: any) => { - const value = params.row.lastResult; + const value = params.row.feedback.payload.output; if (value == null) { return ; } - return ; + return ( + + + + ); }, }, { field: 'lastRanAt', headerName: 'Last Ran At', - flex: 1, + width: 100, renderCell: (params: any) => { - const value = params.row.lastRanAt - ? params.row.lastRanAt.getTime() / 1000 - : undefined; + const createdAt = new Date(params.row.feedback.created_at + 'Z'); + const value = createdAt ? createdAt.getTime() / 1000 : undefined; if (value == null) { return ; } @@ -184,27 +204,20 @@ export const CallActionsViewer: React.FC<{ { field: 'run', headerName: 'Run', - flex: 1, - renderCell: (params: any) => ( - - ), + width: 70, + renderCell: (params: any) => + actionRunnableRefs.has(params.row.feedback.runnable_ref) ? 
( + + ) : null, }, ]; - const rows = allCallActions.map((action, index) => ({ - id: index, - action: action.actionDef.name, - runCount: action.runCount, - lastResult: action.lastResult, - lastRanAt: action.lastRanAt, - actionRef: action.actionRef, - })); return ( <> Date: Wed, 6 Nov 2024 17:13:20 -0800 Subject: [PATCH 069/120] Fixed columns --- .../Home/Browse3/pages/CallPage/CallScoresViewer.tsx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 9cd71d08ade..6a25179b4f3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -135,7 +135,7 @@ export const CallScoresViewer: React.FC<{ { field: 'scorer', headerName: 'Scorer', - width: 100, + width: 150, renderCell: params => { return params.row.feedback.feedback_type.split('.').pop(); }, @@ -178,7 +178,6 @@ export const CallScoresViewer: React.FC<{ sx={{ width: '100%', display: 'flex', - justifyContent: 'center', height: '100%', lineHeight: '20px', alignItems: 'center', @@ -203,7 +202,7 @@ export const CallScoresViewer: React.FC<{ }, { field: 'run', - headerName: 'Run', + headerName: '', width: 70, renderCell: (params: any) => actionRunnableRefs.has(params.row.feedback.runnable_ref) ? 
( From db6dd1927379463403afe286e723646fa7777d72 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 20:03:12 -0800 Subject: [PATCH 070/120] comments --- .../litellm/test_actions_lifecycle_llm_judge.py | 4 +--- weave/trace_server/actions_worker/dispatcher.py | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py b/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py index 59f8c17e260..11100c3110c 100644 --- a/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py +++ b/tests/integrations/litellm/test_actions_lifecycle_llm_judge.py @@ -25,9 +25,7 @@ class DummySecretFetcher: def fetch(self, secret_name: str) -> dict: - return { - "secrets": {secret_name: os.environ.get(secret_name, "DUMMY_SECRET_VALUE")} - } + return {"secrets": {secret_name: os.getenv(secret_name, "DUMMY_SECRET_VALUE")}} primitive_mock_response = { diff --git a/weave/trace_server/actions_worker/dispatcher.py b/weave/trace_server/actions_worker/dispatcher.py index 90f06d02085..74828356ada 100644 --- a/weave/trace_server/actions_worker/dispatcher.py +++ b/weave/trace_server/actions_worker/dispatcher.py @@ -12,7 +12,10 @@ ContainsWordsActionSpec, LlmJudgeActionSpec, ) -from weave.trace_server.interface.feedback_types import RunnablePayloadSchema +from weave.trace_server.interface.feedback_types import ( + RUNNABLE_FEEDBACK_TYPE_PREFIX, + RunnablePayloadSchema, +) from weave.trace_server.refs_internal import ( InternalCallRef, InternalObjectRef, @@ -127,7 +130,7 @@ def publish_results_as_feedback( FeedbackCreateReq( project_id=project_id, weave_ref=weave_ref, - feedback_type="wandb.runnable." + action_name, + feedback_type=RUNNABLE_FEEDBACK_TYPE_PREFIX + "." 
+ action_name, runnable_ref=action_ref, payload=RunnablePayloadSchema(output=result).model_dump(), wb_user_id=wb_user_id, From 052c74b51c669fc7a4d81ebfd26a8ac9881abfb6 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 20:24:45 -0800 Subject: [PATCH 071/120] comments --- tests/trace/test_actions_lifecycle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/trace/test_actions_lifecycle.py b/tests/trace/test_actions_lifecycle.py index 7ccae08e819..e044a7d403f 100644 --- a/tests/trace/test_actions_lifecycle.py +++ b/tests/trace/test_actions_lifecycle.py @@ -12,6 +12,9 @@ ) +# Note: the word count action spec is super simple +# whereas the llm one requires the full litellm code path. +# The LLM tests are in `test_actions_lifecycle_llm_judge.py`. def test_action_lifecycle_word_count(client: WeaveClient): if client_is_sqlite(client): return pytest.skip("skipping for sqlite") @@ -32,6 +35,7 @@ def test_action_lifecycle_word_count(client: WeaveClient): action_ref_uri = published_ref.uri() # Part 2: Demonstrate manual feedback (this is not user-facing) + # This could be it's own test, but it's convenient to have it here. @weave.op def example_op(input: str) -> str: return input + "!!!" 
From 8e308613a607f09ca42a794a6b98b5444b38160a Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 21:16:07 -0800 Subject: [PATCH 072/120] some final fixes --- .../pages/CallPage/CallScoresViewer.tsx | 88 ++++++++++++------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 6a25179b4f3..134dd87f182 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -93,17 +93,20 @@ export const CallScoresViewer: React.FC<{ ).sort((a, b) => (a.val.name ?? '').localeCompare(b.val.name ?? '')); const actionRunnableRefs = useMemo(() => { - return new Set( + return _.fromPairs( actionDefinitions.map(actionDefinition => { - return objectVersionKeyToRefUri({ - scheme: WEAVE_REF_SCHEME, - weaveKind: 'object', - entity: props.call.entity, - project: props.call.project, - objectId: actionDefinition.object_id, - versionHash: actionDefinition.digest, - path: '', - }); + return [ + 'wandb.runnable.' 
+ actionDefinition.object_id, + objectVersionKeyToRefUri({ + scheme: WEAVE_REF_SCHEME, + weaveKind: 'object', + entity: props.call.entity, + project: props.call.project, + objectId: actionDefinition.object_id, + versionHash: actionDefinition.digest, + path: '', + }), + ]; }) ); }, [actionDefinitions, props.call.entity, props.call.project]); @@ -115,29 +118,41 @@ export const CallScoresViewer: React.FC<{ ); }, [feedbackQuery.result]); - const rows = useMemo(() => { - return _.sortBy( - Object.entries(_.groupBy(runnableFeedbacks, f => f.feedback_type)).map( - ([runnableRef, fs]) => { - const val = _.reverse(_.sortBy(fs, 'created_at'))[0]; - return { - id: val.feedback_type, - feedback: val, - runCount: fs.length, - }; - } - ), - s => s.feedback.feedback_type - ); + const scoredRows = useMemo(() => { + return Object.entries( + _.groupBy(runnableFeedbacks, f => f.feedback_type) + ).map(([runnableRef, fs]) => { + const val = _.reverse(_.sortBy(fs, 'created_at'))[0]; + return { + id: val.feedback_type, + feedback: val, + runCount: fs.length, + }; + }); }, [runnableFeedbacks]); + const rows = useMemo(() => { + const additionalRows = actionDefinitions + .map(actionDefinition => { + return { + id: 'wandb.runnable.' 
+ actionDefinition.object_id, + feedback: null, + runCount: 0, + }; + }) + .filter(row => !scoredRows.some(r => r.id === row.id)); + return _.sortBy([...scoredRows, ...additionalRows], s => s.id); + }, [actionDefinitions, scoredRows]); + + console.log('actionDefinitions', actionDefinitions); + const columns: Array> = [ { field: 'scorer', headerName: 'Scorer', width: 150, renderCell: params => { - return params.row.feedback.feedback_type.split('.').pop(); + return params.row.id.split('.').pop(); }, }, { @@ -145,6 +160,9 @@ export const CallScoresViewer: React.FC<{ headerName: 'Logic', width: 60, renderCell: params => { + if (params.row.feedback == null) { + return null; + } return ( { + renderCell: params => { + if (params.row.feedback == null) { + return null; + } const value = params.row.feedback.payload.output; if (value == null) { return ; @@ -191,7 +212,10 @@ export const CallScoresViewer: React.FC<{ field: 'lastRanAt', headerName: 'Last Ran At', width: 100, - renderCell: (params: any) => { + renderCell: params => { + if (params.row.feedback == null) { + return null; + } const createdAt = new Date(params.row.feedback.created_at + 'Z'); const value = createdAt ? createdAt.getTime() / 1000 : undefined; if (value == null) { @@ -204,16 +228,18 @@ export const CallScoresViewer: React.FC<{ field: 'run', headerName: '', width: 70, - renderCell: (params: any) => - actionRunnableRefs.has(params.row.feedback.runnable_ref) ? ( + renderCell: params => { + const actionRef = actionRunnableRefs[params.row.id]; + return actionRef ? 
( - ) : null, + ) : null; + }, }, ]; From 3e8aa1e0ce368a46cc97e12e445efc5314dc281f Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 6 Nov 2024 21:28:30 -0800 Subject: [PATCH 073/120] lint fix --- weave/trace_server/actions_worker/actions/llm_judge.py | 10 +++++----- weave/trace_server/actions_worker/dispatcher.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/weave/trace_server/actions_worker/actions/llm_judge.py b/weave/trace_server/actions_worker/actions/llm_judge.py index f22cd1a1731..5ba23d5bdc4 100644 --- a/weave/trace_server/actions_worker/actions/llm_judge.py +++ b/weave/trace_server/actions_worker/actions/llm_judge.py @@ -44,14 +44,14 @@ def do_llm_judge_action( completion = trace_server.completions_create( CompletionsCreateReq( project_id=project_id, - inputs=dict( - model=model, - messages=[ + inputs={ + "model": model, + "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": json.dumps(args)}, ], - response_format=response_format, - ), + "response_format": response_format, + }, track_llm_call=False, ) ) diff --git a/weave/trace_server/actions_worker/dispatcher.py b/weave/trace_server/actions_worker/dispatcher.py index 74828356ada..9111e8d5026 100644 --- a/weave/trace_server/actions_worker/dispatcher.py +++ b/weave/trace_server/actions_worker/dispatcher.py @@ -60,11 +60,11 @@ def execute_batch( # 1. 
Lookup the action definition parsed_ref = parse_internal_uri(batch_req.action_ref) if parsed_ref.project_id != project_id: - raise ValueError( + raise TypeError( f"Action ref {batch_req.action_ref} does not match project_id {project_id}" ) if not isinstance(parsed_ref, InternalObjectRef): - raise ValueError(f"Action ref {batch_req.action_ref} is not an object ref") + raise TypeError(f"Action ref {batch_req.action_ref} is not an object ref") action_def_read = trace_server.obj_read( ObjReadReq( @@ -124,7 +124,7 @@ def publish_results_as_feedback( weave_ref = InternalCallRef(project_id, call_id).uri() parsed_action_ref = parse_internal_uri(action_ref) if not isinstance(parsed_action_ref, (InternalObjectRef, InternalOpRef)): - raise ValueError(f"Invalid action ref: {action_ref}") + raise TypeError(f"Invalid action ref: {action_ref}") action_name = parsed_action_ref.name return trace_server.feedback_create( FeedbackCreateReq( From 840010f6599e00e58749ea71a19deed2682aeca9 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 7 Nov 2024 08:56:26 -0800 Subject: [PATCH 074/120] fixed name requirement --- .../Home/Browse3/pages/ScorersPage/ScorersPage.tsx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index 319b42efaab..961e740eb80 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -24,10 +24,8 @@ export const ScorersPage: React.FC<{ title="Scorers" tabs={[ { - // It is true that this panel can show more than LLM Judges, but the - // branding is better - label: 'Configurable Judges', - content: , + label: 'Programmatic Scorers', + content: , }, // This is a placeholder for Griffin's annotation column manager 
section // { @@ -35,8 +33,10 @@ export const ScorersPage: React.FC<{ // content: , // }, { - label: 'Programmatic Scorers', - content: , + // It is true that this panel can show more than LLM Judges, but the + // branding is better + label: 'Configurable Judges', + content: , }, ]} headerContent={undefined} From e49112d494397ee623cf84dfc6445d6a71e4296e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 7 Nov 2024 17:40:28 -0800 Subject: [PATCH 075/120] small fix --- .../Home/Browse3/pages/CallPage/CallScoresViewer.tsx | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 134dd87f182..61ceeb0e407 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -20,6 +20,9 @@ import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; + +const RUNNABLE_REF_PREFIX = 'wandb.runnable'; + // New RunButton component const RunButton: React.FC<{ actionRef: string; @@ -96,7 +99,7 @@ export const CallScoresViewer: React.FC<{ return _.fromPairs( actionDefinitions.map(actionDefinition => { return [ - 'wandb.runnable.' + actionDefinition.object_id, + RUNNABLE_REF_PREFIX + '.' + actionDefinition.object_id, objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', @@ -114,7 +117,7 @@ export const CallScoresViewer: React.FC<{ const runnableFeedbacks: Feedback[] = useMemo(() => { return (feedbackQuery.result ?? 
[]).filter( f => - f.feedback_type?.startsWith('wandb.runnable') && f.runnable_ref !== null + f.feedback_type?.startsWith(RUNNABLE_REF_PREFIX) && f.runnable_ref !== null ); }, [feedbackQuery.result]); @@ -135,7 +138,7 @@ export const CallScoresViewer: React.FC<{ const additionalRows = actionDefinitions .map(actionDefinition => { return { - id: 'wandb.runnable.' + actionDefinition.object_id, + id: RUNNABLE_REF_PREFIX + '.' + actionDefinition.object_id, feedback: null, runCount: 0, }; @@ -152,7 +155,7 @@ export const CallScoresViewer: React.FC<{ headerName: 'Scorer', width: 150, renderCell: params => { - return params.row.id.split('.').pop(); + return params.row.id.slice(RUNNABLE_REF_PREFIX.length + 1); }, }, { From 57c1e3f3e4e5f5dbd706298527ad35f579ac72c7 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 7 Nov 2024 18:36:25 -0800 Subject: [PATCH 076/120] pickup new change --- .../pages/CallPage/CallScoresViewer.tsx | 28 ++++++++--------- .../Home/Browse3/pages/ObjectVersionPage.tsx | 2 +- .../Home/Browse3/pages/ObjectVersionsPage.tsx | 2 +- ...nitionModal.tsx => NewActionSpecModal.tsx} | 30 +++++++++---------- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 21 +++++++------ ...onSpecs.ts => actionSpecConfigurations.ts} | 8 ++--- .../pages/common/TypeVersionCategoryChip.tsx | 2 +- .../pages/wfReactInterface/constants.ts | 2 +- 8 files changed, 46 insertions(+), 49 deletions(-) rename weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/{NewActionDefinitionModal.tsx => NewActionSpecModal.tsx} (72%) rename weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/{actionDefinitionConfigurationSpecs.ts => actionSpecConfigurations.ts} (93%) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 61ceeb0e407..c0ab65103ee 100644 --- 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -20,7 +20,6 @@ import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; - const RUNNABLE_REF_PREFIX = 'wandb.runnable'; // New RunButton component @@ -85,8 +84,8 @@ export const CallScoresViewer: React.FC<{ weaveRef, }); - const actionDefinitions = ( - useBaseObjectInstances('ActionDefinition', { + const actionSpecs = ( + useBaseObjectInstances('ActionSpec', { project_id: projectIdFromParts({ entity: props.call.entity, project: props.call.project, @@ -97,27 +96,28 @@ export const CallScoresViewer: React.FC<{ const actionRunnableRefs = useMemo(() => { return _.fromPairs( - actionDefinitions.map(actionDefinition => { + actionSpecs.map(actionSpec => { return [ - RUNNABLE_REF_PREFIX + '.' + actionDefinition.object_id, + RUNNABLE_REF_PREFIX + '.' + actionSpec.object_id, objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', entity: props.call.entity, project: props.call.project, - objectId: actionDefinition.object_id, - versionHash: actionDefinition.digest, + objectId: actionSpec.object_id, + versionHash: actionSpec.digest, path: '', }), ]; }) ); - }, [actionDefinitions, props.call.entity, props.call.project]); + }, [actionSpecs, props.call.entity, props.call.project]); const runnableFeedbacks: Feedback[] = useMemo(() => { return (feedbackQuery.result ?? 
[]).filter( f => - f.feedback_type?.startsWith(RUNNABLE_REF_PREFIX) && f.runnable_ref !== null + f.feedback_type?.startsWith(RUNNABLE_REF_PREFIX) && + f.runnable_ref !== null ); }, [feedbackQuery.result]); @@ -135,19 +135,19 @@ export const CallScoresViewer: React.FC<{ }, [runnableFeedbacks]); const rows = useMemo(() => { - const additionalRows = actionDefinitions - .map(actionDefinition => { + const additionalRows = actionSpecs + .map(actionSpec => { return { - id: RUNNABLE_REF_PREFIX + '.' + actionDefinition.object_id, + id: RUNNABLE_REF_PREFIX + '.' + actionSpec.object_id, feedback: null, runCount: 0, }; }) .filter(row => !scoredRows.some(r => r.id === row.id)); return _.sortBy([...scoredRows, ...additionalRows], s => s.id); - }, [actionDefinitions, scoredRows]); + }, [actionSpecs, scoredRows]); - console.log('actionDefinitions', actionDefinitions); + console.log('actionSpecs', actionSpecs); const columns: Array> = [ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index f60cf7c7007..cc119784233 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -54,7 +54,7 @@ const OBJECT_ICONS: Record = { Evaluation: 'baseline-alt', Leaderboard: 'benchmark-square', Scorer: 'type-number-alt', - ActionDefinition: 'rocket-launch', + ActionSpec: 'rocket-launch', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx index c6b6353caf1..375854ab34c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx @@ -171,7 +171,7 @@ export const FilterableObjectVersionsTable: React.FC<{ propsEmpty = EMPTY_PROPS_LEADERBOARDS; } else if (base === 'Scorer') { propsEmpty = EMPTY_PROPS_PROGRAMATIC_SCORERS; - } else if (base === 'ActionDefinition') { + } else if (base === 'ActionSpec') { propsEmpty = EMPTY_PROPS_ACTION_DEFINITIONS; } return ; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx similarity index 72% rename from weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx rename to weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx index 8225abe9907..cbc42e4e65f 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionDefinitionModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx @@ -10,23 +10,23 @@ import React, {FC, useEffect, useState} from 'react'; import {DynamicConfigForm} from '../../DynamicConfigForm'; import {ReusableDrawer} from '../../ReusableDrawer'; import { - ActionDefinition, - ActionDefinitionSchema, + ActionSpec, + ActionSpecSchema, ActionType, } from '../wfReactInterface/generatedBaseObjectClasses.zod'; -import {actionDefinitionConfigurationSpecs} from './actionDefinitionConfigurationSpecs'; +import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; -interface NewActionDefinitionModalProps { +interface NewActionSpecModalProps { open: boolean; onClose: () => void; - onSave: (newAction: ActionDefinition) => void; + onSave: (newAction: ActionSpec) => void; initialTemplate?: { actionType: ActionType; template: {name: string; config: Record}; } | null; } -export const NewActionDefinitionModal: FC = ({ 
+export const NewActionSpecModal: FC = ({ open, onClose, onSave, @@ -36,8 +36,8 @@ export const NewActionDefinitionModal: FC = ({ const [selectedActionType, setSelectedActionType] = useState('llm_judge'); const [config, setConfig] = useState>({}); - const selectedActionDefinitionConfigurationSpec = - actionDefinitionConfigurationSpecs[selectedActionType]; + const selectedActionSpecConfigurationSpec = + actionSpecConfigurationSpecs[selectedActionType]; useEffect(() => { if (initialTemplate) { @@ -51,12 +51,12 @@ export const NewActionDefinitionModal: FC = ({ }, [initialTemplate]); const handleSave = () => { - if (!selectedActionDefinitionConfigurationSpec) { + if (!selectedActionSpecConfigurationSpec) { return; } - const newAction = ActionDefinitionSchema.parse({ + const newAction = ActionSpecSchema.parse({ name, - spec: selectedActionDefinitionConfigurationSpec.convert(config as any), + config: selectedActionSpecConfigurationSpec.convert(config as any), }); onSave(newAction); setConfig({}); @@ -85,7 +85,7 @@ export const NewActionDefinitionModal: FC = ({ - {selectedActionDefinitionConfigurationSpec && ( + {selectedActionSpecConfigurationSpec && ( , + content: , }, ]} headerContent={undefined} @@ -59,7 +59,7 @@ const CodeScorersTab: React.FC<{ ); }; -const ActionDefinitionsTab: React.FC<{ +const ActionSpecsTab: React.FC<{ entity: string; project: string; }> = ({entity, project}) => { @@ -68,8 +68,7 @@ const ActionDefinitionsTab: React.FC<{ actionType: ActionType; template: {name: string; config: Record}; } | null>(null); - const createCollectionObject = - useCreateBaseObjectInstance('ActionDefinition'); + const createCollectionObject = useCreateBaseObjectInstance('ActionSpec'); const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); const [anchorEl, setAnchorEl] = React.useState(null); @@ -102,7 +101,7 @@ const ActionDefinitionsTab: React.FC<{ setSelectedTemplate(null); }; - const handleSaveModal = (newAction: ActionDefinition) => { + const 
handleSaveModal = (newAction: ActionSpec) => { let objectId = newAction.name; // Remove non alphanumeric characters // TODO: reconcile this null-name issue @@ -159,7 +158,7 @@ const ActionDefinitionsTab: React.FC<{ /> - {Object.entries(actionDefinitionConfigurationSpecs).flatMap( + {Object.entries(actionSpecConfigurationSpecs).flatMap( ([actionType, spec]) => spec.templates.map(template => ( - = { @@ -55,8 +55,8 @@ type actionDefinitionConfigurationSpec< }>; }; -export const actionDefinitionConfigurationSpecs: Partial< - Record +export const actionSpecConfigurationSpecs: Partial< + Record > = { llm_judge: { name: 'LLM Judge', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx index 12527a52c00..7408f0357b3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx @@ -10,7 +10,7 @@ const colorMap: Record = { Evaluation: 'cactus', Leaderboard: 'gold', Scorer: 'purple', - ActionDefinition: 'sienna', + ActionSpec: 'sienna', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index bf54110c369..60f97b2d28d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -26,5 +26,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Evaluation', 'Leaderboard', 'Scorer', - 'ActionDefinition', + 'ActionSpec', ] as const; From 2cd872b98768ddf1eeb8af92ec370ab4c8a858ed Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 
2024 17:28:10 -0800 Subject: [PATCH 077/120] lint --- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 31 +++++++++---------- .../Browse3/pages/common/EmptyContent.tsx | 2 -- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index a6ea95012d8..c894a5f499a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -1,25 +1,24 @@ - -import {Box} from '@material-ui/core'; -import Menu from '@mui/material/Menu'; -import MenuItem from '@mui/material/MenuItem'; -import {Button} from '@wandb/weave/components/Button/Button'; -import React, {FC, useState} from 'react'; +// import {Box} from '@material-ui/core'; +// import Menu from '@mui/material/Menu'; +// import MenuItem from '@mui/material/MenuItem'; +// import {Button} from '@wandb/weave/components/Button/Button'; +// import React, {FC, useState} from 'react'; import React from 'react'; +// import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; -import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; -import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; -import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; -import { - ActionSpec, - ActionType, -} from '../wfReactInterface/generatedBaseObjectClasses.zod'; -import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; -import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; +// import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; +// import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +// import { +// 
ActionSpec, +// ActionType, +// } from '../wfReactInterface/generatedBaseObjectClasses.zod'; +// import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; +// import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; import {ActionSpecsTab} from './ActionSpecsTab'; import {AnnotationsTab} from './AnnotationsTab'; import {ProgrammaticScorersTab} from './CoreScorersTab'; -import {NewActionSpecModal} from './NewActionSpecModal'; +// import {NewActionSpecModal} from './NewActionSpecModal'; export const ScorersPage: React.FC<{ entity: string; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx index 43870ea7d44..cab77196a98 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/EmptyContent.tsx @@ -184,7 +184,6 @@ export const EMPTY_NO_TRACE_SERVER: EmptyProps = { ), }; - export const EMPTY_PROPS_PROGRAMMATIC_SCORERS: EmptyProps = { icon: 'type-number-alt' as const, heading: 'No programmatic scorers yet', @@ -200,7 +199,6 @@ export const EMPTY_PROPS_PROGRAMMATIC_SCORERS: EmptyProps = { ), }; - export const EMPTY_PROPS_ACTION_DEFINITIONS: EmptyProps = { icon: 'automation-robot-arm' as const, heading: 'No configurations yet', From 8daa7ca07536d6677b328df61b3968da2b238075 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 17:51:02 -0800 Subject: [PATCH 078/120] lint --- .../components/FancyPage/useProjectSidebar.ts | 7 - .../pages/ScorersPage/ActionSpecsTab.tsx | 183 ++++++++++++++++- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 190 ------------------ 3 files changed, 182 insertions(+), 198 deletions(-) diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index 0a24cfdd1c7..92d8d7a7a50 100644 --- 
a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -165,13 +165,6 @@ export const useProjectSidebar = ( // isShown: isWeaveOnly, // iconName: IconNames.TypeNumberAlt, // }, - { - type: 'button' as const, - name: 'Scorers', - slug: 'weave/scorers', - isShown: isWeaveOnly, - iconName: IconNames.TypeNumberAlt, - }, { type: 'divider' as const, key: 'dividerWithinWeave-2', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx index 25b6c4d76aa..8638d42dbb3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx @@ -1,8 +1,189 @@ +// import {Box} from '@material-ui/core'; +// import Menu from '@mui/material/Menu'; +// import MenuItem from '@mui/material/MenuItem'; +// import {Button} from '@wandb/weave/components/Button/Button'; import React from 'react'; +// import React, {FC, useState} from 'react'; +import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; +// import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +// import { +// ActionSpec, +// ActionType, +// } from '../wfReactInterface/generatedBaseObjectClasses.zod'; +// import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; +// import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; +// import {NewActionSpecModal} from './NewActionSpecModal'; + export const ActionSpecsTab: React.FC<{ entity: string; project: string; }> = ({entity, project}) => { - return <>Coming Soon - Configurable Judges; + return ( + + ); }; + +// const ActionSpecsTab: React.FC<{ +// entity: string; +// project: string; +// }> = ({entity, project}) => { +// const [isModalOpen, setIsModalOpen] = 
useState(false); +// const [selectedTemplate, setSelectedTemplate] = useState<{ +// actionType: ActionType; +// template: {name: string; config: Record}; +// } | null>(null); +// const createCollectionObject = useCreateBaseObjectInstance('ActionSpec'); +// const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); + +// const [anchorEl, setAnchorEl] = React.useState(null); +// const open = Boolean(anchorEl); + +// const handleCreateBlank = () => { +// setSelectedTemplate(null); +// setIsModalOpen(true); +// }; + +// const handleDropdownClick = (event: React.MouseEvent) => { +// setAnchorEl(event.currentTarget); +// }; + +// const handleClose = () => { +// setAnchorEl(null); +// }; + +// const handleTemplateSelect = (template: { +// actionType: ActionType; +// template: {name: string; config: Record}; +// }) => { +// setSelectedTemplate(template); +// setIsModalOpen(true); +// handleClose(); +// }; + +// const handleCloseModal = () => { +// setIsModalOpen(false); +// setSelectedTemplate(null); +// }; + +// const handleSaveModal = (newAction: ActionSpec) => { +// let objectId = newAction.name; +// // Remove non alphanumeric characters +// // TODO: reconcile this null-name issue +// objectId = objectId?.replace(/[^a-zA-Z0-9]/g, '-') ?? 
''; +// createCollectionObject({ +// obj: { +// project_id: projectIdFromParts({entity, project}), +// object_id: objectId, +// val: newAction, +// }, +// }) +// .then(() => { +// setLastUpdatedTimestamp(Date.now()); +// }) +// .catch(err => { +// console.error(err); +// }) +// .finally(() => { +// handleCloseModal(); +// }); +// }; + +// return ( +// +// +// +// +// +// +// ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index c894a5f499a..504fde151dd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -1,24 +1,9 @@ -// import {Box} from '@material-ui/core'; -// import Menu from '@mui/material/Menu'; -// import MenuItem from '@mui/material/MenuItem'; -// import {Button} from '@wandb/weave/components/Button/Button'; -// import React, {FC, useState} from 'react'; import React from 'react'; -// import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; -// import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; -// import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; -// import { -// ActionSpec, -// ActionType, -// } from '../wfReactInterface/generatedBaseObjectClasses.zod'; -// import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; -// import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; import {ActionSpecsTab} from './ActionSpecsTab'; import {AnnotationsTab} from './AnnotationsTab'; import {ProgrammaticScorersTab} from './CoreScorersTab'; -// import {NewActionSpecModal} from './NewActionSpecModal'; export const ScorersPage: React.FC<{ entity: string; @@ -45,178 +30,3 @@ export const ScorersPage: 
React.FC<{ /> ); }; -// <<<<<<< HEAD -// const CodeScorersTab: React.FC<{ -// entity: string; -// project: string; -// }> = ({entity, project}) => { -// return ( -// -// ); -// }; - -// const ActionSpecsTab: React.FC<{ -// entity: string; -// project: string; -// }> = ({entity, project}) => { -// const [isModalOpen, setIsModalOpen] = useState(false); -// const [selectedTemplate, setSelectedTemplate] = useState<{ -// actionType: ActionType; -// template: {name: string; config: Record}; -// } | null>(null); -// const createCollectionObject = useCreateBaseObjectInstance('ActionSpec'); -// const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); - -// const [anchorEl, setAnchorEl] = React.useState(null); -// const open = Boolean(anchorEl); - -// const handleCreateBlank = () => { -// setSelectedTemplate(null); -// setIsModalOpen(true); -// }; - -// const handleDropdownClick = (event: React.MouseEvent) => { -// setAnchorEl(event.currentTarget); -// }; - -// const handleClose = () => { -// setAnchorEl(null); -// }; - -// const handleTemplateSelect = (template: { -// actionType: ActionType; -// template: {name: string; config: Record}; -// }) => { -// setSelectedTemplate(template); -// setIsModalOpen(true); -// handleClose(); -// }; - -// const handleCloseModal = () => { -// setIsModalOpen(false); -// setSelectedTemplate(null); -// }; - -// const handleSaveModal = (newAction: ActionSpec) => { -// let objectId = newAction.name; -// // Remove non alphanumeric characters -// // TODO: reconcile this null-name issue -// objectId = objectId?.replace(/[^a-zA-Z0-9]/g, '-') ?? 
''; -// createCollectionObject({ -// obj: { -// project_id: projectIdFromParts({entity, project}), -// object_id: objectId, -// val: newAction, -// }, -// }) -// .then(() => { -// setLastUpdatedTimestamp(Date.now()); -// }) -// .catch(err => { -// console.error(err); -// }) -// .finally(() => { -// handleCloseModal(); -// }); -// }; - -// return ( -// -// -// -// -// -// -// ); -// ======= -// >>>>>>> master From 3ffac60176192aca6db0b853fed5b34ecd9cbab0 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 18:05:22 -0800 Subject: [PATCH 079/120] done --- .../pages/ScorersPage/ActionSpecsTab.tsx | 330 +++++++++--------- 1 file changed, 157 insertions(+), 173 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx index 8638d42dbb3..a4eae5402b2 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ActionSpecsTab.tsx @@ -1,189 +1,173 @@ -// import {Box} from '@material-ui/core'; -// import Menu from '@mui/material/Menu'; -// import MenuItem from '@mui/material/MenuItem'; -// import {Button} from '@wandb/weave/components/Button/Button'; -import React from 'react'; +import {Box} from '@material-ui/core'; +import Menu from '@mui/material/Menu'; +import MenuItem from '@mui/material/MenuItem'; +import {Button} from '@wandb/weave/components/Button/Button'; +import React, {FC, useState} from 'react'; -// import React, {FC, useState} from 'react'; import {FilterableObjectVersionsTable} from '../ObjectVersionsPage'; -// import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; -// import { -// ActionSpec, -// ActionType, -// } from '../wfReactInterface/generatedBaseObjectClasses.zod'; -// import {projectIdFromParts} from 
'../wfReactInterface/tsDataModelHooks'; -// import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; -// import {NewActionSpecModal} from './NewActionSpecModal'; +import {useCreateBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +import { + ActionSpec, + ActionType, +} from '../wfReactInterface/generatedBaseObjectClasses.zod'; +import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; +import {actionSpecConfigurationSpecs} from './actionSpecConfigurations'; +import {NewActionSpecModal} from './NewActionSpecModal'; export const ActionSpecsTab: React.FC<{ entity: string; project: string; }> = ({entity, project}) => { - return ( - - ); -}; - -// const ActionSpecsTab: React.FC<{ -// entity: string; -// project: string; -// }> = ({entity, project}) => { -// const [isModalOpen, setIsModalOpen] = useState(false); -// const [selectedTemplate, setSelectedTemplate] = useState<{ -// actionType: ActionType; -// template: {name: string; config: Record}; -// } | null>(null); -// const createCollectionObject = useCreateBaseObjectInstance('ActionSpec'); -// const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); + const [isModalOpen, setIsModalOpen] = useState(false); + const [selectedTemplate, setSelectedTemplate] = useState<{ + actionType: ActionType; + template: {name: string; config: Record}; + } | null>(null); + const createCollectionObject = useCreateBaseObjectInstance('ActionSpec'); + const [lastUpdatedTimestamp, setLastUpdatedTimestamp] = useState(0); -// const [anchorEl, setAnchorEl] = React.useState(null); -// const open = Boolean(anchorEl); + const [anchorEl, setAnchorEl] = React.useState(null); + const open = Boolean(anchorEl); -// const handleCreateBlank = () => { -// setSelectedTemplate(null); -// setIsModalOpen(true); -// }; + const handleCreateBlank = () => { + setSelectedTemplate(null); + setIsModalOpen(true); + }; -// const handleDropdownClick = (event: React.MouseEvent) => { -// 
setAnchorEl(event.currentTarget); -// }; + const handleDropdownClick = (event: React.MouseEvent) => { + setAnchorEl(event.currentTarget); + }; -// const handleClose = () => { -// setAnchorEl(null); -// }; + const handleClose = () => { + setAnchorEl(null); + }; -// const handleTemplateSelect = (template: { -// actionType: ActionType; -// template: {name: string; config: Record}; -// }) => { -// setSelectedTemplate(template); -// setIsModalOpen(true); -// handleClose(); -// }; + const handleTemplateSelect = (template: { + actionType: ActionType; + template: {name: string; config: Record}; + }) => { + setSelectedTemplate(template); + setIsModalOpen(true); + handleClose(); + }; -// const handleCloseModal = () => { -// setIsModalOpen(false); -// setSelectedTemplate(null); -// }; + const handleCloseModal = () => { + setIsModalOpen(false); + setSelectedTemplate(null); + }; -// const handleSaveModal = (newAction: ActionSpec) => { -// let objectId = newAction.name; -// // Remove non alphanumeric characters -// // TODO: reconcile this null-name issue -// objectId = objectId?.replace(/[^a-zA-Z0-9]/g, '-') ?? ''; -// createCollectionObject({ -// obj: { -// project_id: projectIdFromParts({entity, project}), -// object_id: objectId, -// val: newAction, -// }, -// }) -// .then(() => { -// setLastUpdatedTimestamp(Date.now()); -// }) -// .catch(err => { -// console.error(err); -// }) -// .finally(() => { -// handleCloseModal(); -// }); -// }; + const handleSaveModal = (newAction: ActionSpec) => { + let objectId = newAction.name; + // Remove non alphanumeric characters + // TODO: reconcile this null-name issue + objectId = objectId?.replace(/[^a-zA-Z0-9]/g, '-') ?? 
''; + createCollectionObject({ + obj: { + project_id: projectIdFromParts({entity, project}), + object_id: objectId, + val: newAction, + }, + }) + .then(() => { + setLastUpdatedTimestamp(Date.now()); + }) + .catch(err => { + console.error(err); + }) + .finally(() => { + handleCloseModal(); + }); + }; -// return ( -// -// -// -// -// + -// -// ); +export const AddNewButton: FC<{ + onClick: () => void; + disabled?: boolean; + tooltipText?: string; +}> = ({onClick, disabled, tooltipText}) => ( + + + +); From 3a3b70c1f7d46e6057452eb935778784048e45ba Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 18:06:50 -0800 Subject: [PATCH 080/120] done --- .../Browse3/pages/wfReactInterface/traceServerClientTypes.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index 6e5fd393ef7..6c6bcaf4581 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -171,7 +171,10 @@ export type Feedback = { created_at: string; feedback_type: string; payload: Record; + annotation_ref?: string; runnable_ref?: string; + call_ref?: string; + trigger_ref?: string; }; export type FeedbackQuerySuccess = { From bb5697ecdb0e9df32785e33d6d191d5cb3cb0a57 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 18:24:23 -0800 Subject: [PATCH 081/120] remove dead code --- .../pages/OpVersionPage/OpVersionPage.tsx | 152 ------------------ 1 file changed, 152 deletions(-) delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx deleted file mode 100644 index 9e14df7152c..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage/OpVersionPage.tsx +++ /dev/null @@ -1,152 +0,0 @@ -import React, {useMemo} from 'react'; - -import {LoadingDots} from '../../../../../LoadingDots'; -import {Tailwind} from '../../../../../Tailwind'; -import {NotFoundPanel} from '../../NotFoundPanel'; -import {OpCodeViewer} from '../../OpCodeViewer'; -import { - CallsLink, - opNiceName, - OpVersionsLink, - opVersionText, -} from '../common/Links'; -import {CenteredAnimatedLoader} from '../common/Loader'; -import { - ScrollableTabContent, - SimpleKeyValueTable, - SimplePageLayoutWithHeader, -} from '../common/SimplePageLayout'; -import {TabUseOp} from '../TabUseOp'; -import {useWFHooks} from '../wfReactInterface/context'; -import {opVersionKeyToRefUri} from '../wfReactInterface/utilities'; -import {OpVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface'; - -export const OpVersionPage: React.FC<{ - entity: string; - project: string; - opName: string; - version: string; -}> = props => { - const {useOpVersion} = useWFHooks(); - - const opVersion = useOpVersion({ - entity: props.entity, - project: props.project, - opId: props.opName, - versionHash: props.version, - }); - if (opVersion.loading) { - return ; - } else if (opVersion.result == null) { - return ; - } - return ; -}; - -const OpVersionPageInner: React.FC<{ - opVersion: OpVersionSchema; -}> = ({opVersion}) => { - const {useOpVersions, useCallsStats} = useWFHooks(); - const uri = opVersionKeyToRefUri(opVersion); - const {entity, project, opId, versionIndex} = opVersion; - - const opVersions = useOpVersions( - entity, - project, - { - opIds: [opId], - }, - undefined, // limit - true // metadataOnly - ); - const opVersionCount = (opVersions.result ?? 
[]).length; - const callsStats = useCallsStats(entity, project, { - opVersionRefs: [uri], - }); - const opVersionCallCount = callsStats?.result?.count ?? 0; - const useOpSupported = useMemo(() => { - // TODO: We really want to return `True` only when - // the op is not a bound op. However, we don't have - // that data available yet. - return true; - }, []); - - return ( - - {opId}{' '} - {opVersions.loading ? ( - - ) : ( - <> - [ - - ] - - )} - - ), - Version: <>{versionIndex}, - Calls: - !callsStats.loading || opVersionCallCount > 0 ? ( - - ) : ( - <> - ), - }} - /> - } - tabs={[ - { - label: 'Code', - content: ( - - ), - }, - ...(useOpSupported - ? [ - { - label: 'Use', - content: ( - - - - - - ), - }, - ] - : []), - ]} - /> - ); -}; From eeb783c66f1089aed77e03979ff79baef7389d07 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 18:27:34 -0800 Subject: [PATCH 082/120] remove dead code --- weave-js/src/components/PagePanelComponents/Home/Browse3.tsx | 2 +- .../Home/Browse3/pages/common/SimplePageLayout.tsx | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index d18b961172c..16140868a1a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -92,7 +92,7 @@ import { } from './Browse3/pages/ObjectVersionsPage'; import {OpPage} from './Browse3/pages/OpPage'; import {OpsPage} from './Browse3/pages/OpsPage'; -import {OpVersionPage} from './Browse3/pages/OpVersionPage/OpVersionPage'; +import {OpVersionPage} from './Browse3/pages/OpVersionPage'; import {OpVersionsPage} from './Browse3/pages/OpVersionsPage'; import {PlaygroundPage} from './Browse3/pages/PlaygroundPage/PlaygroundPage'; import {ScorersPage} from './Browse3/pages/ScorersPage/ScorersPage'; diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index fb9756fc7cc..8099c1fe98b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -347,6 +347,7 @@ export const SimpleKeyValueTable: FC<{ display: 'flex', flexDirection: 'column', justifyContent: 'flex-start', + width: 100, }}> {key} From d694d520fc7f84798b0bbb47cd96b7c8708dfc58 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 18:52:30 -0800 Subject: [PATCH 083/120] removed line change --- .../Browse3/pages/common/SimplePageLayout.tsx | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index 8099c1fe98b..e4f08956eb9 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -272,23 +272,21 @@ export const SimplePageLayoutWithHeader: FC<{ {props.headerContent} {(!props.hideTabsIfSingle || tabs.length > 1) && ( - - - - {tabs.map(tab => ( - - {tab.label} - - ))} - - - + + + {tabs.map(tab => ( + + {tab.label} + + ))} + + )} Date: Tue, 12 Nov 2024 18:57:15 -0800 Subject: [PATCH 084/120] revert small ref --- .../PagePanelComponents/Home/Browse2/SmallRef.tsx | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx index 832a9c6b39e..e6f4c0955a8 100644 --- 
a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx @@ -80,8 +80,7 @@ export const SmallRef: FC<{ objRef: ObjectRef; wfTable?: WFDBTableType; iconOnly?: boolean; - nameOnly?: boolean; -}> = ({objRef, wfTable, iconOnly = false, nameOnly = false}) => { +}> = ({objRef, wfTable, iconOnly = false}) => { const { useObjectVersion, useOpVersion, @@ -141,9 +140,7 @@ export const SmallRef: FC<{ // TODO: Why is this necessary? The type is coming back as `objRef` rootType = {type: 'OpDef'}; } - const {label} = nameOnly - ? {label: objRef.artifactName} - : objectRefDisplayName(objRef, versionIndex); + const {label} = objectRefDisplayName(objRef, versionIndex); const rootTypeName = getTypeName(rootType); let icon: IconName = IconNames.CubeContainer; From 1a125a723addb3221845396312689d84b8c3d92d Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 12 Nov 2024 19:14:08 -0800 Subject: [PATCH 085/120] Pre-refactor --- .../Home/Browse3/DynamicConfigForm.tsx | 122 ++++++++++++------ .../ScorersPage/actionSpecConfigurations.ts | 5 +- 2 files changed, 83 insertions(+), 44 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx index e5561c77f9c..1e96bbb78cd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/DynamicConfigForm.tsx @@ -9,9 +9,10 @@ import { MenuItem, Select, TextField, + Tooltip, Typography, } from '@material-ui/core'; -import {Delete} from '@mui/icons-material'; +import {Delete, Help} from '@mui/icons-material'; import React, {useEffect, useMemo, useState} from 'react'; import {z} from 'zod'; @@ -263,16 +264,19 @@ const NestedForm: React.FC<{ } return ( - - updateConfig(currentPath, e.target.value, config, setConfig) - } - margin="dense" - /> + + 
+ updateConfig(currentPath, e.target.value, config, setConfig) + } + margin="dense" + /> + + ); }; @@ -399,11 +403,14 @@ const EnumField: React.FC<{ return ( - {keyName !== '' ? ( - {keyName} - ) : ( -
- )} + + {keyName !== '' ? ( + {keyName} + ) : ( +
+ )} + + - { @@ -161,7 +161,7 @@ const NestedForm: React.FC<{ {keyName} - } config={config} setConfig={setConfig} @@ -814,7 +814,7 @@ const DescriptionTooltip: React.FC<{description?: string}> = ({description}) => ); }; -export const DynamicConfigForm: React.FC = ({ +export const ZSForm: React.FC = ({ configSchema, config, setConfig, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx index cbc42e4e65f..621cd6d1bf7 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewActionSpecModal.tsx @@ -7,8 +7,8 @@ import { } from '@material-ui/core'; import React, {FC, useEffect, useState} from 'react'; -import {DynamicConfigForm} from '../../DynamicConfigForm'; import {ReusableDrawer} from '../../ReusableDrawer'; +import {ZSForm} from '../../ZodSchemaForm/ZodSchemaForm'; import { ActionSpec, ActionSpecSchema, @@ -95,7 +95,7 @@ export const NewActionSpecModal: FC = ({ {selectedActionSpecConfigurationSpec && ( - Date: Wed, 13 Nov 2024 15:01:33 -0800 Subject: [PATCH 087/120] Adopting fields --- .../Browse3/ZodSchemaForm/ZodSchemaForm.tsx | 84 +++++++++---------- .../Home/Browse3/filters/SelectField.tsx | 3 +- .../pages/ScorersPage/NewActionSpecModal.tsx | 16 ++-- .../ScorersPage/actionSpecConfigurations.ts | 41 +++++---- 4 files changed, 77 insertions(+), 67 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx index c2bb8cc9329..e9e27dc7202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx 
@@ -6,8 +6,6 @@ import { FormControlLabel, IconButton, InputLabel, - MenuItem, - Select, TextField, Tooltip, Typography, @@ -16,7 +14,7 @@ import {Delete, Help} from '@mui/icons-material'; import React, {useEffect, useMemo, useState} from 'react'; import {z} from 'zod'; -import {parseRefMaybe} from '../../Browse2/SmallRef'; +import {SelectField} from '../filters/SelectField'; interface ZSFormProps { configSchema: z.ZodType; @@ -52,6 +50,17 @@ const unwrapSchema = (schema: z.ZodTypeAny): z.ZodTypeAny => { return schema; }; +const distiminatorOptionToValue = ( + option: z.ZodTypeAny, + discriminator: string +) => { + return option._def.shape()[discriminator]._def.value; +}; + +const Label: React.FC<{label: string}> = ({label}) => { + return {label} +}; + const DiscriminatedUnionField: React.FC<{ keyName: string; fieldSchema: z.ZodDiscriminatedUnion< @@ -67,7 +76,8 @@ const DiscriminatedUnionField: React.FC<{ const options = fieldSchema._def.options; const currentType = - value?.[discriminator] || options[0]._def.shape()[discriminator]._def.value; + value?.[discriminator] || + distiminatorOptionToValue(options[0], discriminator); const handleTypeChange = (newType: string) => { const selectedOption = options.find( @@ -101,21 +111,19 @@ const DiscriminatedUnionField: React.FC<{ }, {} as Record) ); + + return ( - {keyName} - + onSelectField={v => handleTypeChange(v as string)} + /> s instanceof z.ZodObject)) { return ( - {keyName} +
)} - + value={selectedValue} + onSelectField={v => updateConfig(targetPath, v, config, setConfig)} + /> ); }; @@ -519,7 +515,7 @@ const RecordField: React.FC<{ return ( - {keyName} + {selectedActionSpecConfigurationSpec && ( - +
+ +
)} ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionSpecConfigurations.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionSpecConfigurations.ts index 52381f2623b..6d9859f29cf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionSpecConfigurations.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/actionSpecConfigurations.ts @@ -17,29 +17,32 @@ import {ActionType} from '../wfReactInterface/generatedBaseObjectClasses.zod'; // Define `ConfiguredLlmJudgeActionFriendlySchema` const SimpleResponseFormatSchema = z - .enum(['boolean', 'number', 'string']) - .default('boolean'); + .enum(['Boolean', 'Number', 'String']) + .default('Boolean'); const StructuredResponseFormatSchema = z.record(SimpleResponseFormatSchema); const ResponseFormatSchema = z.discriminatedUnion('type', [ z.object({ - type: z.literal('simple'), - schema: SimpleResponseFormatSchema, + type: z.literal('Boolean'), }), z.object({ - type: z.literal('structured'), - schema: StructuredResponseFormatSchema, + type: z.literal('Number'), + }), + z.object({ + type: z.literal('String'), + }), + z.object({ + type: z.literal('Object'), + Schema: StructuredResponseFormatSchema, }), ]); const ConfiguredLlmJudgeActionFriendlySchema = z.object({ - model: z - .enum(['gpt-4o-mini', 'gpt-4o']) - .default('gpt-4o-mini') - .describe('Model to use for judging. Please configure OPENAPI_API_KEY in team settings.'), - prompt: z.string().min(3), - response_schema: ResponseFormatSchema, + Model: z.enum(['gpt-4o-mini', 'gpt-4o']).default('gpt-4o-mini'), + // .describe('Model to use for judging. 
Please configure OPENAPI_API_KEY in team settings.'), + Prompt: z.string().min(3), + 'Response Schema': ResponseFormatSchema, }); // End of `ConfiguredLlmJudgeActionFriendlySchema` @@ -71,12 +74,16 @@ export const actionSpecConfigurationSpecs: Partial< let responseFormat: z.infer< typeof LlmJudgeActionSpecSchema >['response_schema']; - if (data.response_schema.type === 'simple') { - responseFormat = {type: data.response_schema.schema}; + if (data['Response Schema'].type === 'Boolean') { + responseFormat = {type: 'boolean'}; + } else if (data['Response Schema'].type === 'Number') { + responseFormat = {type: 'number'}; + } else if (data['Response Schema'].type === 'String') { + responseFormat = {type: 'string'}; } else { responseFormat = { type: 'object', - properties: _.mapValues(data.response_schema.schema, value => ({ + properties: _.mapValues(data['Response Schema'].Schema, value => ({ type: value as 'boolean' | 'number' | 'string', })), additionalProperties: false, @@ -84,8 +91,8 @@ export const actionSpecConfigurationSpecs: Partial< } return { action_type: 'llm_judge', - model: data.model, - prompt: data.prompt, + model: data.Model, + prompt: data.Prompt, response_schema: responseFormat, }; }, From 0b802cd464573f5338221b3af9f95b2642e8fce6 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 13 Nov 2024 15:43:56 -0800 Subject: [PATCH 088/120] Done with the MVP --- .../Browse3/ZodSchemaForm/ZodSchemaForm.tsx | 113 +++++++++--------- .../ScorersPage/actionSpecConfigurations.ts | 4 +- 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx index e9e27dc7202..04a10493430 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/ZodSchemaForm/ZodSchemaForm.tsx @@ -1,21 +1,26 
@@ import { Box, - Button, + // Button, Checkbox, FormControl, FormControlLabel, IconButton, InputLabel, - TextField, + // TextField, Tooltip, Typography, } from '@material-ui/core'; import {Delete, Help} from '@mui/icons-material'; +import { Button } from '@wandb/weave/components/Button'; +import { TextField } from '@wandb/weave/components/Form/TextField'; import React, {useEffect, useMemo, useState} from 'react'; import {z} from 'zod'; import {SelectField} from '../filters/SelectField'; +const GAP_BETWEEN_ITEMS_PX = 20; +const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; + interface ZSFormProps { configSchema: z.ZodType; config: Record; @@ -58,7 +63,7 @@ const distiminatorOptionToValue = ( }; const Label: React.FC<{label: string}> = ({label}) => { - return {label} + return {label} }; const DiscriminatedUnionField: React.FC<{ @@ -114,7 +119,7 @@ const DiscriminatedUnionField: React.FC<{ return ( - +
- ); - } - - return ( -
- -
- ); -}; const useLatestActionDefinitionsForCall = (call: CallSchema) => { const actionSpecs = ( @@ -311,7 +266,7 @@ export const CallScoresViewer: React.FC<{ { field: 'run', headerName: '', - width: 70, + width:75, rowSpanValueGetter: (value, row) => row.displayName, renderCell: params => { const actionRef = params.row.runnableActionRef; @@ -367,3 +322,51 @@ export const CallScoresViewer: React.FC<{ ); }; + + +const RunButton: React.FC<{ + actionRef: string; + callId: string; + entity: string; + project: string; + refetchFeedback: () => void; +}> = ({actionRef, callId, entity, project, refetchFeedback}) => { + const getClient = useGetTraceServerClientContext(); + + const [isRunning, setIsRunning] = useState(false); + const [error, setError] = useState(null); + + const handleRunClick = async () => { + setIsRunning(true); + setError(null); + try { + await getClient().actionsExecuteBatch({ + project_id: projectIdFromParts({entity, project}), + call_ids: [callId], + action_ref: actionRef, + }); + refetchFeedback(); + } catch (err) { + setError('An error occurred while running the action.'); + } finally { + setIsRunning(false); + } + }; + + + if (error) { + return ( + + ); + } + + return ( +
+ +
+ ); +}; \ No newline at end of file From c77b8e7ffe350deb16f3a194b3420ee3c08868e0 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 18 Nov 2024 15:05:31 -0800 Subject: [PATCH 117/120] some final adjustments --- .../Home/Browse3/pages/CallPage/CallPage.tsx | 8 ++++-- .../pages/CallPage/CallScoresViewer.tsx | 20 ++++++++------ .../pages/ScorersPage/NewScorerDrawer.tsx | 26 ++++++++++++++----- .../Browse3/pages/ScorersPage/ScorersPage.tsx | 17 ++++++++---- 4 files changed, 50 insertions(+), 21 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index d29ae8472fe..2e4556f8c4c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -58,9 +58,13 @@ export const CallPage: FC<{ return ; }; -const useCallTabs = (call: CallSchema) => { +export const useShowRunnableUI = () => { const viewerInfo = useViewerInfo(); - const showScores = viewerInfo.loading ? false : viewerInfo.userInfo?.admin; + return viewerInfo.loading ? 
false : viewerInfo.userInfo?.admin; +}; + +const useCallTabs = (call: CallSchema) => { + const showScores = useShowRunnableUI(); const codeURI = call.opVersionRef; const {entity, project, callId} = call; const weaveRef = makeRefCall(entity, project, callId); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index daa60c951b6..f477332e3f5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -26,8 +26,6 @@ import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; const RUNNABLE_REF_PREFIX = 'wandb.runnable'; - - const useLatestActionDefinitionsForCall = (call: CallSchema) => { const actionSpecs = ( useBaseObjectInstances('ActionSpec', { @@ -266,7 +264,7 @@ export const CallScoresViewer: React.FC<{ { field: 'run', headerName: '', - width:75, + width: 75, rowSpanValueGetter: (value, row) => row.displayName, renderCell: params => { const actionRef = params.row.runnableActionRef; @@ -323,7 +321,6 @@ export const CallScoresViewer: React.FC<{ ); }; - const RunButton: React.FC<{ actionRef: string; callId: string; @@ -353,10 +350,13 @@ const RunButton: React.FC<{ } }; - if (error) { return ( - ); @@ -364,9 +364,13 @@ const RunButton: React.FC<{ return (
-
); -}; \ No newline at end of file +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx index 33fbc1f8c17..79887a72bcd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx @@ -1,8 +1,16 @@ import {Box, Drawer} from '@material-ui/core'; import {Button} from '@wandb/weave/components/Button'; import {Icon, IconName, IconNames} from '@wandb/weave/components/Icon'; -import React, {FC, ReactNode, useCallback, useEffect, useState} from 'react'; - +import React, { + FC, + ReactNode, + useCallback, + useEffect, + useMemo, + useState, +} from 'react'; + +import {useShowRunnableUI} from '../CallPage/CallPage'; import {TraceServerClient} from '../wfReactInterface/traceServerClient'; import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext'; import {AutocompleteWithLabel} from './FormComponents'; @@ -16,7 +24,7 @@ import { const HUMAN_ANNOTATION_LABEL = 'Human annotation'; export const HUMAN_ANNOTATION_VALUE = 'ANNOTATION'; const LLM_JUDGE_LABEL = 'LLM judge'; -const LLM_JUDGE_VALUE = 'LLM_JUDGE'; +export const LLM_JUDGE_VALUE = 'LLM_JUDGE'; const PROGRAMMATIC_LABEL = 'Programmatic scorer'; const PROGRAMMATIC_VALUE = 'PROGRAMMATIC'; @@ -48,7 +56,7 @@ export const scorerTypeRecord: Record> = { }, }, LLM_JUDGE: { - label: LLM_JUDGE_LABEL, + label: LLM_JUDGE_LABEL + ' (W&B Admin Preview)', value: LLM_JUDGE_VALUE, icon: IconNames.RobotServiceMember, Component: LLMJudgeScorerForm.LLMJudgeScorerForm, @@ -124,6 +132,12 @@ export const NewScorerDrawer: FC = ({ }, [selectedScorerType, entity, project, formData, getClient, onClose]); const ScorerFormComponent = scorerTypeRecord[selectedScorerType].Component; + const showRunnableUI = 
useShowRunnableUI(); + const options = useMemo(() => { + return scorerTypeOptions.filter( + opt => showRunnableUI || opt.value !== LLM_JUDGE_VALUE + ); + }, [showRunnableUI]); return ( = ({ saveDisabled={!isFormValid}> opt.value === selectedScorerType)} + options={options} + value={options.find(opt => opt.value === selectedScorerType)} formatOptionLabel={option => ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx index f7690066e61..adf6ff76816 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ScorersPage.tsx @@ -2,6 +2,7 @@ import {Button} from '@wandb/weave/components/Button'; import {IconNames} from '@wandb/weave/components/Icon'; import React, {useState} from 'react'; +import {useShowRunnableUI} from '../CallPage/CallPage'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {AnnotationsTab} from './AnnotationsTab'; import {ProgrammaticScorersTab} from './CoreScorersTab'; @@ -22,6 +23,8 @@ export const ScorersPage: React.FC<{ scorerTypeRecord.ANNOTATION.value ); + const showRunnableUI = useShowRunnableUI(); + return ( <> , }, - { - label: scorerTypeRecord.LLM_JUDGE.label + 's', - icon: scorerTypeRecord.LLM_JUDGE.icon, - content: , - }, + ...(showRunnableUI + ? 
[ + { + label: scorerTypeRecord.LLM_JUDGE.label + 's', + icon: scorerTypeRecord.LLM_JUDGE.icon, + content: , + }, + ] + : []), { label: scorerTypeRecord.PROGRAMMATIC.label + 's', icon: scorerTypeRecord.PROGRAMMATIC.icon, From e9e07fd6afba9b8f0b70e2ba2df855fe3e0e414e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 18 Nov 2024 15:33:35 -0800 Subject: [PATCH 118/120] Done --- .../Home/Browse3/feedback/FeedbackGrid.tsx | 28 +++++++++++++------ .../Home/Browse3/pages/CallPage/CallPage.tsx | 3 ++ .../pages/CallPage/CallScoresViewer.tsx | 13 +++++---- .../pages/ScorersPage/NewScorerDrawer.tsx | 4 +++ 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx index d4330e51575..38956c7d354 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx @@ -1,12 +1,13 @@ import {Box} from '@mui/material'; import _ from 'lodash'; -import React, {useEffect} from 'react'; +import React, {useEffect, useMemo} from 'react'; import {useViewerInfo} from '../../../../../common/hooks/useViewerInfo'; import {TargetBlank} from '../../../../../common/util/links'; import {Alert} from '../../../../Alert'; import {Loading} from '../../../../Loading'; import {Tailwind} from '../../../../Tailwind'; +import {RUNNABLE_FEEDBACK_TYPE_PREFIX} from '../pages/CallPage/CallScoresViewer'; import {Empty} from '../pages/common/Empty'; import {useWFHooks} from '../pages/wfReactInterface/context'; import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; @@ -40,6 +41,23 @@ export const FeedbackGrid = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, []); + // Exclude runnables as they are presented in a different tab + const withoutRunnables = useMemo( + () 
=> + (query.result ?? []).filter( + f => !f.feedback_type.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) + ), + [query.result] + ); + + // Group by feedback on this object vs. descendent objects + const grouped = useMemo( + () => + _.groupBy(withoutRunnables, f => f.weave_ref.substring(weaveRef.length)), + [withoutRunnables, weaveRef] + ); + const paths = useMemo(() => Object.keys(grouped).sort(), [grouped]); + if (query.loading || loadingUserInfo) { return ( - f.weave_ref.substring(weaveRef.length) - ); - const paths = Object.keys(grouped).sort(); - const currentViewerId = userInfo ? userInfo.id : null; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 2e4556f8c4c..679f92ae2f3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -136,6 +136,9 @@ const useCallTabs = (call: CallSchema) => { ), }, + // For now, we are only showing this tab for W&B admins since the + // feature is in active development. We want to be able to get + // feedback without enabling for all users. ...(showScores ? 
[ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index f477332e3f5..effb315774a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -24,7 +24,7 @@ import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; -const RUNNABLE_REF_PREFIX = 'wandb.runnable'; +export const RUNNABLE_FEEDBACK_TYPE_PREFIX = 'wandb.runnable'; const useLatestActionDefinitionsForCall = (call: CallSchema) => { const actionSpecs = ( @@ -51,7 +51,7 @@ const useRunnableFeedbacksForCall = (call: CallSchema) => { const runnableFeedbacks: Feedback[] = useMemo(() => { return (feedbackQuery.result ?? []).filter( f => - f.feedback_type?.startsWith(RUNNABLE_REF_PREFIX) && + f.feedback_type?.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) && f.runnable_ref !== null ); }, [feedbackQuery.result]); @@ -67,7 +67,7 @@ const useRunnableFeedbackTypeToLatestActionRef = ( return _.fromPairs( actionSpecs.map(actionSpec => { return [ - RUNNABLE_REF_PREFIX + '.' + actionSpec.object_id, + RUNNABLE_FEEDBACK_TYPE_PREFIX + '.' 
+ actionSpec.object_id, objectVersionKeyToRefUri({ scheme: WEAVE_REF_SCHEME, weaveKind: 'object', @@ -103,7 +103,9 @@ const useTableRowsForRunnableFeedbacks = ( const val = _.reverse(_.sortBy(fs, 'created_at'))[0]; return { id: feedbackType, - displayName: feedbackType.slice(RUNNABLE_REF_PREFIX.length + 1), + displayName: feedbackType.slice( + RUNNABLE_FEEDBACK_TYPE_PREFIX.length + 1 + ), runnableActionRef: runnableFeedbackTypeToLatestActionRef[val.feedback_type], feedback: val, @@ -112,7 +114,8 @@ const useTableRowsForRunnableFeedbacks = ( }); const additionalRows = actionSpecs .map(actionSpec => { - const feedbackType = RUNNABLE_REF_PREFIX + '.' + actionSpec.object_id; + const feedbackType = + RUNNABLE_FEEDBACK_TYPE_PREFIX + '.' + actionSpec.object_id; return { id: feedbackType, runnableActionRef: diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx index 79887a72bcd..d8733c87362 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/NewScorerDrawer.tsx @@ -133,6 +133,10 @@ export const NewScorerDrawer: FC = ({ const ScorerFormComponent = scorerTypeRecord[selectedScorerType].Component; const showRunnableUI = useShowRunnableUI(); + + // Here, we hide the LLM judge option from non-admins since the + // feature is in active development. We want to be able to get + // feedback without enabling for all users. 
const options = useMemo(() => { return scorerTypeOptions.filter( opt => showRunnableUI || opt.value !== LLM_JUDGE_VALUE From 99c6095fcf2eed49403c766ab3b1c72a7e6cb688 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 18 Nov 2024 15:41:43 -0800 Subject: [PATCH 119/120] Done --- wb_schema.gql | 1 + 1 file changed, 1 insertion(+) diff --git a/wb_schema.gql b/wb_schema.gql index ce587bb0ee9..f326ab4d487 100644 --- a/wb_schema.gql +++ b/wb_schema.gql @@ -169,6 +169,7 @@ type User implements Node { photoUrl: String deletedAt: DateTime teams(before: String, after: String, first: Int, last: Int): EntityConnection + admin: Boolean } type UserConnection { From b37c1f064ce3c73bef3531968d775ff2979819ea Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 18 Nov 2024 18:24:01 -0800 Subject: [PATCH 120/120] Merged in --- .../Home/Browse3/feedback/FeedbackGrid.tsx | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx index e8bf686cb96..72e436c62bc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx @@ -44,18 +44,17 @@ export const FeedbackGrid = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, []); - // Group by feedback on this object vs. descendent objects - const grouped = useMemo( - () => - { - // Exclude runnables as they are presented in a different tab - const withoutRunnables = (query.result ?? []).filter( - f => !f.feedback_type.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) - ) - // Combine annotation feedback on (feedback_type, creator) + const grouped = useMemo(() => { + // Exclude runnables as they are presented in a different tab + const withoutRunnables = (query.result ?? 
[]).filter( + f => !f.feedback_type.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) + ); + // Combine annotation feedback on (feedback_type, creator) const combined = _.groupBy( - withoutRunnables.filter(f => f.feedback_type.startsWith(ANNOTATION_PREFIX)), + withoutRunnables.filter(f => + f.feedback_type.startsWith(ANNOTATION_PREFIX) + ), f => `${f.feedback_type}-${f.creator}` ); // only keep the most recent feedback for each (feedback_type, creator) @@ -64,18 +63,16 @@ export const FeedbackGrid = ({ ); // add the non-annotation feedback to the combined object combinedFiltered.push( - ...withoutRunnables.filter(f => !f.feedback_type.startsWith(ANNOTATION_PREFIX)) + ...withoutRunnables.filter( + f => !f.feedback_type.startsWith(ANNOTATION_PREFIX) + ) ); - + // Group by feedback on this object vs. descendent objects return _.groupBy(combinedFiltered, f => - f.weave_ref.substring(weaveRef.length) - ); - }, - [query.result, weaveRef] - ); - - + f.weave_ref.substring(weaveRef.length) + ); + }, [query.result, weaveRef]); const paths = useMemo(() => Object.keys(grouped).sort(), [grouped]); @@ -100,7 +97,7 @@ export const FeedbackGrid = ({ ); } - if (!withoutRunnables.length) { + if (!paths.length) { return (