diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx index 03e7d462cb04..4aba7f882e57 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/context.tsx @@ -336,7 +336,8 @@ export const browse3ContextGen = ( traceId: string, callId: string, path?: string | null, - tracetree?: boolean + tracetree?: boolean, + feedbackExpand?: boolean ) => { let url = `${projectRoot(entityName, projectName)}/calls/${callId}`; const params = new URLSearchParams(); @@ -346,6 +347,9 @@ export const browse3ContextGen = ( if (tracetree !== undefined) { params.set(TRACETREE_PARAM, tracetree ? '1' : '0'); } + if (feedbackExpand !== undefined) { + params.set(FEEDBACK_EXPAND_PARAM, feedbackExpand ? '1' : '0'); + } if (params.toString()) { url += '?' + params.toString(); } @@ -497,7 +501,8 @@ type RouteType = { traceId: string, callId: string, path?: string | null, - tracetree?: boolean + tracetree?: boolean, + feedbackExpand?: boolean ) => string; tracesUIUrl: (entityName: string, projectName: string) => string; callsUIUrl: ( @@ -564,6 +569,7 @@ const useSetSearchParam = () => { export const PEEK_PARAM = 'peekPath'; export const TRACETREE_PARAM = 'tracetree'; +export const FEEDBACK_EXPAND_PARAM = 'feedbackExpand'; export const PATH_PARAM = 'path'; export const baseContext = browse3ContextGen( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx index d4330e515750..0ea8708ed23d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx @@ -11,6 +11,9 @@ import {Empty} from '../pages/common/Empty'; import {useWFHooks} from '../pages/wfReactInterface/context'; import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; import {FeedbackGridInner} from './FeedbackGridInner'; +import {HUMAN_ANNOTATION_BASE_TYPE} from './StructuredFeedback/humanAnnotationTypes'; + +const ANNOTATION_PREFIX = `${HUMAN_ANNOTATION_BASE_TYPE}.`; type FeedbackGridProps = { entity: string; @@ -81,8 +84,22 @@ export const FeedbackGrid = ({ ); } + // Combine annotation feedback on (feedback_type, creator) + const combined = _.groupBy( + query.result.filter(f => f.feedback_type.startsWith(ANNOTATION_PREFIX)), + f => `${f.feedback_type}-${f.creator}` + ); + // only keep the most recent feedback for each (feedback_type, creator) + const combinedFiltered = Object.values(combined).map( + fs => fs.sort((a, b) => b.created_at - a.created_at)[0] + ); + // add the non-annotation feedback to the combined object + combinedFiltered.push( + ...query.result.filter(f => !f.feedback_type.startsWith(ANNOTATION_PREFIX)) + ); + // Group by feedback on this object vs. descendent objects - const grouped = _.groupBy(query.result, f => + const grouped = _.groupBy(combinedFiltered, f => f.weave_ref.substring(weaveRef.length) ); const paths = Object.keys(grouped).sort(); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx index fb5b0574564b..5a47b8d1bd28 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx @@ -42,6 +42,13 @@ export const FeedbackGridInner = ({ if (params.row.feedback_type === 'wandb.reaction.1') { return params.row.payload.emoji; } + if (params.row.feedback_type.startsWith('wandb.annotation.')) { + return ( + + ); + } return ; }, }, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackTypeChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackTypeChip.tsx index 2daeb94fda10..aa5f1330d26e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackTypeChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackTypeChip.tsx @@ -15,6 +15,9 @@ export const FeedbackTypeChip = ({feedbackType}: FeedbackTypeChipProps) => { } else if (feedbackType === 'wandb.note.1') { color = 'gold'; label = 'Note'; + } else if (feedbackType.includes('wandb.annotation.')) { + color = 'magenta'; + label = 'Annotation'; } return ; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx new file mode 100644 index 000000000000..249a2335704b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -0,0 +1,248 @@ +import {toast} from '@wandb/weave/common/components/elements/Toast'; +import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo'; +import {Button} from '@wandb/weave/components/Button'; +import {Icon} from '@wandb/weave/components/Icon'; +import {makeRefCall} from '@wandb/weave/util/refs'; +import React, {useState} from 'react'; +import {useHistory} from 'react-router-dom'; + +import {useWeaveflowRouteContext} from '../../context'; +import {Empty} from '../../pages/common/Empty'; +import {EMPTY_PROPS_ANNOTATIONS} from '../../pages/common/EmptyContent'; +import {HumanAnnotationCell} from './HumanAnnotation'; +import {tsHumanAnnotationSpec} from './humanAnnotationTypes'; + +type FeedbackSidebarProps = { + humanAnnotationSpecs: tsHumanAnnotationSpec[]; + callID: string; + entity: string; + project: string; +}; + +export const FeedbackSidebar = ({ + humanAnnotationSpecs, + callID, + entity, + project, +}: FeedbackSidebarProps) => { + const history = useHistory(); + const router = useWeaveflowRouteContext().baseRouter; + const [isSaving, setIsSaving] = useState(false); + const [unsavedFeedbackChanges, setUnsavedFeedbackChanges] = useState< + Record Promise> + >({}); + + const save = async () => { + setIsSaving(true); + try { + // Save all pending feedback changes + const savePromises = Object.values(unsavedFeedbackChanges).map(saveFn => + saveFn() + ); + const results = await Promise.all(savePromises); + + // Check if any saves failed + if (results.some(result => !result)) { + throw new Error('Not all feedback changes saved'); + } + + // Clear the unsaved changes after successful save + setUnsavedFeedbackChanges({}); + } catch (error) { + console.error('Error saving feedback:', error); + toast(`Error saving feedback: ${error}`, { + type: 'error', + }); + } finally { + setIsSaving(false); + } + }; + + return ( +
+
+
Feedback
+
+
+ {humanAnnotationSpecs.length > 0 ? ( + <> +
+ +
+
+ +
+ + ) : ( +
+ +
+ +
+
+ )} +
+ ); +}; + +type HumanAnnotationSectionProps = { + entity: string; + project: string; + callID: string; + humanAnnotationSpecs: tsHumanAnnotationSpec[]; + setUnsavedFeedbackChanges: React.Dispatch< + React.SetStateAction Promise>> + >; +}; + +const HumanAnnotationSection = ({ + entity, + project, + callID, + humanAnnotationSpecs, + setUnsavedFeedbackChanges, +}: HumanAnnotationSectionProps) => { + const [isExpanded, setIsExpanded] = useState(true); + const sortedVisibleColumns = humanAnnotationSpecs.sort((a, b) => + (a.name ?? '').localeCompare(b.name ?? '') + ); + + return ( +
+ + {isExpanded && ( + + )} +
+ ); +}; + +type HumanAnnotationHeaderProps = { + numHumanAnnotationSpecsVisible: number; + numHumanAnnotationSpecsHidden: number; + isExpanded: boolean; + setIsExpanded: (isExpanded: boolean) => void; +}; + +const HumanAnnotationHeader = ({ + numHumanAnnotationSpecsVisible, + numHumanAnnotationSpecsHidden, + isExpanded, + setIsExpanded, +}: HumanAnnotationHeaderProps) => { + return ( + + ); +}; + +type HumanAnnotationInputsProps = { + entity: string; + project: string; + callID: string; + humanAnnotationSpecs: tsHumanAnnotationSpec[]; + setUnsavedFeedbackChanges: React.Dispatch< + React.SetStateAction Promise>> + >; +}; + +const HumanAnnotationInputs = ({ + entity, + project, + callID, + humanAnnotationSpecs, + setUnsavedFeedbackChanges, +}: HumanAnnotationInputsProps) => { + const callRef = makeRefCall(entity, project, callID); + const {loading: loadingUserInfo, userInfo} = useViewerInfo(); + + if (loadingUserInfo) { + return null; + } + const viewer = userInfo ? userInfo.id : null; + + return ( +
+ {humanAnnotationSpecs?.map((field, index) => ( +
+
{field.name}
+ {field.description && ( +
+ {field.description} +
+ )} +
+ +
+
+ ))} +
+ ); +}; + +const DisplayNumericCounter = ({count}: {count: number}) => { + return ( +
+ {count} +
+ ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx new file mode 100644 index 000000000000..9a1d1438786d --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -0,0 +1,650 @@ +import {Autocomplete, TextField as MuiTextField} from '@mui/material'; +import {toast} from '@wandb/weave/common/components/elements/Toast'; +import {MOON_300} from '@wandb/weave/common/css/color.styles'; +import {Button} from '@wandb/weave/components/Button'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {LoadingDots} from '@wandb/weave/components/LoadingDots'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {parseRef} from '@wandb/weave/react'; +import _ from 'lodash'; +import React, {useCallback, useEffect, useMemo, useState} from 'react'; + +import {CellValueString} from '../../../Browse2/CellValueString'; +import {useWFHooks} from '../../pages/wfReactInterface/context'; +import {useGetTraceServerClientContext} from '../../pages/wfReactInterface/traceServerClientContext'; +import { + FeedbackCreateError, + FeedbackCreateSuccess, +} from '../../pages/wfReactInterface/traceServerClientTypes'; +import { + FEEDBACK_TYPE_OPTIONS, + HumanAnnotation, + HumanAnnotationPayload, + makeAnnotationFeedbackType, + tsHumanAnnotationSpec, +} from './humanAnnotationTypes'; + +const DEBOUNCE_VAL = 200; + +// Interfaces +type HumanAnnotationProps = { + entity: string; + project: string; + viewer: string | null; + hfSpec: tsHumanAnnotationSpec; + callRef: string; + readOnly?: boolean; + focused?: boolean; + setUnsavedFeedbackChanges: React.Dispatch< + React.SetStateAction Promise>> + >; +}; + +export const HumanAnnotationCell: React.FC = props => { + const getTsClient = useGetTraceServerClientContext(); + const tsClient = getTsClient(); + const {useFeedback} = useWFHooks(); + const [foundFeedback, setFoundFeedback] = useState([]); + const query = useFeedback({ + entity: props.entity, + project: props.project, + weaveRef: props.callRef, + }); + const foundFeedbackCallRef = query?.result?.[0]?.weave_ref; + const feedbackSpecRef = props.hfSpec.ref; + + useEffect(() => { + if (!props.readOnly) { + // We don't need to listen for feedback changes if the cell is editable + // it is being controlled by local state + return; + } + return getTsClient().registerOnFeedbackListener( + props.callRef, + query.refetch + ); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [props.callRef]); + + useEffect(() => { + if (foundFeedbackCallRef && props.callRef !== foundFeedbackCallRef) { + // The call was changed, we need to reset + setFoundFeedback([]); + } + }, [props.callRef, foundFeedbackCallRef]); + + const onAddFeedback = async (value: number | string): Promise => { + try { + const requestProps: FeedbackRequestProps = { + entity: props.entity, + project: props.project, + viewer: props.viewer, + callRef: props.callRef, + feedbackSpecRef, + value, + }; + const createRequest = generateFeedbackRequestPayload(requestProps); + const promise = tsClient.feedbackCreate(createRequest).then(res => { + if ('detail' in res) { + const errorRes = res as FeedbackCreateError; + toast(`Feedback create failed: ${errorRes.detail}`, { + type: 'error', + }); + return false; + } + const successRes = res as FeedbackCreateSuccess; + return !!successRes.id; + }); + return await promise; + } catch (error) { + toast(`Error adding feedback: ${error}`, { + type: 'error', + }); + return false; + } + }; + + useEffect(() => { + if (query?.loading) { + return; + } + + const feedbackRefMatches = (feedback: HumanAnnotation) => + feedback.annotation_ref === feedbackSpecRef; + + const currFeedback = query.result?.filter((feedback: HumanAnnotation) => + feedbackRefMatches(feedback) + ); + if (!currFeedback || currFeedback.length === 0) { + return; + } + + setFoundFeedback(currFeedback); + }, [query?.result, query?.loading, feedbackSpecRef]); + + const extractedValues = useMemo( + () => extractFeedbackValues(foundFeedback, props.viewer, feedbackSpecRef), + [foundFeedback, props.viewer, feedbackSpecRef] + ); + const {rawValues, mostRecentVal, viewerFeedbackVal} = extractedValues; + + const type = useMemo( + () => inferTypeFromJsonSchema(props.hfSpec.json_schema ?? {}), + [props.hfSpec.json_schema] + ); + + if (query?.loading) { + return ; + } + if (props.readOnly) { + return ( +
+ +
+ ); + } + let foundValue = mostRecentVal; + if (props.hfSpec.unique_among_creators) { + foundValue = viewerFeedbackVal; + } + return ( +
+ +
+ ); +}; + +const FeedbackComponentSelector: React.FC<{ + type: string | null; + jsonSchema: Record; + focused: boolean; + onAddFeedback: (value: any) => Promise; + foundValue: string | number | null; + feedbackSpecRef: string; + setUnsavedFeedbackChanges: React.Dispatch< + React.SetStateAction Promise>> + >; +}> = React.memo( + ({ + type, + jsonSchema, + focused, + onAddFeedback, + foundValue, + feedbackSpecRef, + setUnsavedFeedbackChanges, + }) => { + const wrappedOnAddFeedback = useCallback( + async (value: any) => { + setUnsavedFeedbackChanges(curr => ({ + ...curr, + [feedbackSpecRef]: () => onAddFeedback(value), + })); + return true; + }, + [onAddFeedback, setUnsavedFeedbackChanges, feedbackSpecRef] + ); + + switch (type) { + case 'number': + return ( + + ); + case 'string': + return ( + + ); + case 'enum': + return ( + + ); + case 'boolean': + return ( + + ); + default: + return <>; + } + } +); + +type ExtractedFeedbackValues = { + // The leaves of the feedback tree, just the raw values + rawValues: Array; + // The most recent feedback value from the CURRENT viewer + viewerFeedbackVal: string | number | null; + // The most recent feedback value from ANY viewer + mostRecentVal: string | number | null; + // The combined feedback from all viewers + // userId : value + combinedFeedback: Record; +}; + +const extractFeedbackValues = ( + foundFeedback: HumanAnnotation[], + viewer: string | null, + columnRef: string +): ExtractedFeedbackValues => { + // filter out feedback for previous columns, then combine by creator + const combinedFeedback = foundFeedback + .filter(feedback => feedback.annotation_ref === columnRef) + .reduce((acc, feedback) => { + return { + [feedback.creator ?? '']: feedback.payload.value, + ...acc, + }; + }, {}) as Record; + + const rawValues = Object.values(combinedFeedback).filter(Boolean); + const viewerFeedbackVal = combinedFeedback[viewer ?? ''] ?? null; + const mostRecentVal = foundFeedback[0]?.payload.value ?? null; + + return { + rawValues, + viewerFeedbackVal, + mostRecentVal, + combinedFeedback, + }; +}; + +type FeedbackRequestProps = { + entity: string; + project: string; + viewer: string | null; + callRef: string; + feedbackSpecRef: string; + value: any; +}; + +// Utility function for creating feedback request +const generateFeedbackRequestPayload = ({ + entity, + project, + viewer, + callRef, + feedbackSpecRef, + value, +}: FeedbackRequestProps) => { + const parsedRef = parseRef(feedbackSpecRef); + const humanAnnotationPayload: HumanAnnotationPayload = { + value, + }; + const feedbackType = makeAnnotationFeedbackType(parsedRef.artifactName); + const baseRequest = { + project_id: `${entity}/${project}`, + weave_ref: callRef, + creator: viewer, + feedback_type: feedbackType, + annotation_ref: feedbackSpecRef, + payload: humanAnnotationPayload, + sort_by: [{created_at: 'desc'}], + }; + return baseRequest; +}; + +const inferTypeFromJsonSchema = (jsonSchema: Record) => { + if ('enum' in jsonSchema) { + return 'enum'; + } + if (FEEDBACK_TYPE_OPTIONS.includes(jsonSchema.type)) { + return jsonSchema.type; + } + return null; +}; + +export const NumericalFeedbackColumn = ({ + min, + max, + onAddFeedback, + defaultValue, + focused, + isInteger, +}: { + min: number; + max: number; + onAddFeedback?: (value: number | null) => Promise; + defaultValue: number | null; + focused?: boolean; + isInteger?: boolean; +}) => { + const debouncedFn = useMemo( + () => + _.debounce((val: number | null) => onAddFeedback?.(val), DEBOUNCE_VAL), + [onAddFeedback] + ); + useEffect(() => { + return () => { + debouncedFn.cancel(); + }; + }, [debouncedFn]); + + return ( + + ); +}; + +export const TextFeedbackColumn = ({ + onAddFeedback, + defaultValue, + focused, + maxLength, +}: { + onAddFeedback?: (value: string) => Promise; + defaultValue: string | null; + focused?: boolean; + maxLength?: number; +}) => { + const [value, setValue] = useState(defaultValue ?? ''); + + const debouncedFn = useMemo( + () => _.debounce((val: string) => onAddFeedback?.(val), DEBOUNCE_VAL), + [onAddFeedback] + ); + // Cleanup debounce on unmount + useEffect(() => { + return () => { + debouncedFn.cancel(); + }; + }, [debouncedFn]); + + useEffect(() => { + setValue(defaultValue ?? ''); + }, [defaultValue]); + + const onValueChange = useCallback( + (newValue: string) => { + setValue(newValue); + debouncedFn(newValue); + }, + [debouncedFn] + ); + + return ( +
+ + {maxLength && ( +
+ {`character max: ${maxLength}`} +
+ )} +
+ ); +}; + +type Option = { + label: string; + value: string; +}; + +export const EnumFeedbackColumn = ({ + options, + onAddFeedback, + defaultValue, + focused, +}: { + options: string[]; + onAddFeedback?: (value: string) => Promise; + defaultValue: string | null; + focused?: boolean; +}) => { + const dropdownOptions = useMemo(() => { + const opts = options.map((option: string) => ({ + label: option, + value: option, + })); + opts.splice(0, 0, {label: '', value: ''}); + return opts; + }, [options]); + + const [value, setValue] = useState