diff --git a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx index 551bb638c3..911b9cd923 100644 --- a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx +++ b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx @@ -19,7 +19,9 @@ import { Box, Button, Card, + Checkbox, CircularProgress, + FormControlLabel, IconButton, MenuItem, Select, @@ -44,6 +46,7 @@ import { formatValue, getNumericDifference, useTriggerEvent } from "./utils"; const KEY_COLOR = "#ff6d04"; const COMPARE_KEY_COLOR = "#03a9f4"; +const DEFAULT_BAR_CONFIG = { sortBy: "az" }; export default function Evaluation(props: EvaluationProps) { const { @@ -64,15 +67,14 @@ export default function Evaluation(props: EvaluationProps) { const [expanded, setExpanded] = React.useState("summary"); const [mode, setMode] = useState("chart"); const [editNoteState, setEditNoteState] = useState({ open: false, note: "" }); - const [barConfigState, setBarConfigState] = useState({ - sortBy: "", - limit: 0, - }); - const [barConfigDialogState, setBarConfigDialogState] = useState({ - open: false, - sortBy: "best", - limit: 20, - }); + const [classPerformanceConfig, setClassPerformanceConfig] = + useState({}); + const [classPerformanceDialogConfig, setClassPerformanceDialogConfig] = + useState(DEFAULT_BAR_CONFIG); + const [confusionMatrixConfig, setConfusionMatrixConfig] = + useState({ log: true }); + const [confusionMatrixDialogConfig, setConfusionMatrixDialogConfig] = + useState(DEFAULT_BAR_CONFIG); const [metricMode, setMetricMode] = useState("chart"); const [classMode, setClassMode] = useState("chart"); const [performanceClass, setPerformanceClass] = useState("precision"); @@ -87,11 +89,14 @@ export default function Evaluation(props: EvaluationProps) { return evaluation; }, [data]); const confusionMatrix = useMemo(() => { - return evaluation?.confusion_matrix; - }, [evaluation]); + return getMatrix(evaluation?.confusion_matrices, confusionMatrixConfig); + }, [evaluation, confusionMatrixConfig]); const compareConfusionMatrix = useMemo(() => { - return compareEvaluation?.confusion_matrix; - }, [compareEvaluation]); + return getMatrix( + compareEvaluation?.confusion_matrices, + confusionMatrixConfig + ); + }, [compareEvaluation, confusionMatrixConfig]); const compareKeys = useMemo(() => { const keys: string[] = []; const evaluations = data?.evaluations || []; @@ -129,8 +134,11 @@ export default function Evaluation(props: EvaluationProps) { const closeNoteDialog = () => { setEditNoteState((note) => ({ ...note, open: false })); }; - const closeBarConfigDialog = () => { - setBarConfigDialogState((state) => ({ ...state, open: false })); + const closeClassPerformanceConfigDialog = () => { + setClassPerformanceDialogConfig((state) => ({ ...state, open: false })); + }; + const closeConfusionMatrixConfigDialog = () => { + setConfusionMatrixDialogConfig((state) => ({ ...state, open: false })); }; if (!evaluation) { @@ -400,7 +408,7 @@ export default function Evaluation(props: EvaluationProps) { const performanceClasses = Object.keys(perClassPerformance); const classPerformance = formatPerClassPerformance( perClassPerformance[performanceClass], - barConfigState + classPerformanceConfig ); return ( @@ -608,7 +616,13 @@ export default function Evaluation(props: EvaluationProps) { 1 ); const positiveRatio = ratio > 0; - const ratioColor = positiveRatio ? "#8BC18D" : "#FF6464"; + const zeroRatio = ratio === 0; + const negativeRatio = ratio < 0; + const ratioColor = positiveRatio + ? "#8BC18D" + : negativeRatio + ? "#FF6464" + : theme.palette.text.tertiary; const showTrophy = lesserIsBetter ? difference < 0 : difference > 0; @@ -635,13 +649,13 @@ export default function Evaluation(props: EvaluationProps) { )} {filterable && ( { loadView("field", { field: rowId }); }} title="Load view" > - + )} @@ -661,7 +675,10 @@ export default function Evaluation(props: EvaluationProps) { {filterable && ( { loadView("field", { field: rowId, @@ -686,13 +703,22 @@ export default function Evaluation(props: EvaluationProps) { direction="row" sx={{ alignItems: "center" }} > - {positiveRatio ? ( + {positiveRatio && ( - ) : ( + )} + {negativeRatio && ( )} + {zeroRatio && ( + + — + + )} {CLASS_LABELS[performanceClass]} Per Class + {getConfigLabel({ + config: classPerformanceConfig, + type: "classPerformance", + dashed: true, + })} { - setBarConfigDialogState((state) => ({ + setClassPerformanceDialogConfig((state) => ({ ...state, open: true, })); @@ -1039,6 +1070,23 @@ export default function Evaluation(props: EvaluationProps) { Confusion Matrices + + + {getConfigLabel({ config: confusionMatrixConfig })} + + + { + setConfusionMatrixDialogConfig((state) => ({ + ...state, + open: true, + })); + }} + > + + + + { @@ -1082,6 +1132,9 @@ export default function Evaluation(props: EvaluationProps) { x: compareConfusionMatrix?.labels, y: compareConfusionMatrix?.labels, type: "heatmap", + colorscale: confusionMatrixConfig.log + ? compareConfusionMatrix?.colorscale || "viridis" + : "viridis", }, ]} /> @@ -1163,7 +1216,7 @@ export default function Evaluation(props: EvaluationProps) { { @@ -1194,9 +1247,9 @@ export default function Evaluation(props: EvaluationProps) { theme.palette.background.level2 }, }} @@ -1211,50 +1264,141 @@ export default function Evaluation(props: EvaluationProps) { Limit bars: { - setBarConfigDialogState((state) => ({ + const newLimit = parseInt(e.target.value); + setClassPerformanceDialogConfig((state) => { + return { + ...state, + limit: isNaN(newLimit) ? undefined : newLimit, + }; + }); + }} + /> + + + + + + + + theme.palette.background.level2 }, + }} + > + + + + Display Options: Confusion Matrix + + + Sort by: + + + + Limit classes: + { + const newLimit = parseInt(e.target.value); + setConfusionMatrixDialogConfig((state) => { + return { + ...state, + limit: isNaN(newLimit) ? undefined : newLimit, + }; + }); + }} /> + { + setConfusionMatrixDialogConfig((state) => ({ + ...state, + log: checked, + })); + }} + /> + } + />