diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 69923d04c..60dd6ea04 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -58,7 +58,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.x' + python-version: '3.11' architecture: x64 - name: Install dependencies run: | diff --git a/docs/tutorials/preferential-optimization.rst b/docs/tutorials/preferential-optimization.rst index a192e883c..ae6623831 100644 --- a/docs/tutorials/preferential-optimization.rst +++ b/docs/tutorials/preferential-optimization.rst @@ -5,11 +5,11 @@ What is Preferential Optimization? ---------------------------------- Preferential optimization is a method for optimizing hyperparameters, focusing of human preferences, by determining which trial is superior when comparing a pair. -It differs from `human-in-the-loop optimization utilizing objective form widgets `_, +It differs from :ref:`human-in-the-loop optimization utilizing objective form widgets `, which relies on absolute evaluations, as it significantly reduces fluctuations in evaluators' criteria, thus ensuring more consistent results. In this tutorial, we'll interactively optimize RGB values to generate a color resembling a "sunset hue", -aligining with the problem setting in `this tutorial `_. +aligining with the problem setting in :ref:`this tutorial `. Familiarity with the tutorial ob objective form widgets may enhance your understanding. How to Run Preferential Optimization diff --git a/optuna_dashboard/preferential/samplers/gp.py b/optuna_dashboard/preferential/samplers/gp.py index beb5b7c23..3cd980614 100644 --- a/optuna_dashboard/preferential/samplers/gp.py +++ b/optuna_dashboard/preferential/samplers/gp.py @@ -50,15 +50,15 @@ def _orthants_MVN_Gibbs_sampling(cov_inv: Tensor, cycles: int, initial_sample: T def _one_side_trunc_norm_sampling(lower: Tensor) -> Tensor: - if lower > 4.0: - r = torch.clamp_min(torch.rand(torch.Size(()), dtype=torch.float64), min=1e-300) - return (lower * lower - 2 * r.log()).sqrt() - else: - SQRT2 = math.sqrt(2) - r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2) - while 1 - r == 1: - r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2) - return torch.erfinv(1 - r) * SQRT2 + r = torch.rand(torch.Size(()), dtype=torch.float64) + ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log())) + + # If sampled random number is very small, `ret` becomes inf. + while torch.isinf(ret): + r = torch.rand(torch.Size(()), dtype=torch.float64) + ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log())) + + return ret _orthants_MVN_Gibbs_sampling_jit = torch.jit.script(_orthants_MVN_Gibbs_sampling) diff --git a/optuna_dashboard/ts/components/DataGrid.tsx b/optuna_dashboard/ts/components/DataGrid.tsx index 2d6a5670c..24e45dc78 100644 --- a/optuna_dashboard/ts/components/DataGrid.tsx +++ b/optuna_dashboard/ts/components/DataGrid.tsx @@ -28,7 +28,7 @@ interface DataGridColumn { field: keyof T label: string sortable?: boolean - less?: (a: T, b: T) => number + less?: (a: T, b: T, ascending: boolean) => number filterable?: boolean toCellValue?: (rowIndex: number) => string | React.ReactNode padding?: "normal" | "checkbox" | "none" @@ -358,7 +358,10 @@ function stableSort( const stabilizedThis = array.map((el, index) => [el, index] as [T, number]) stabilizedThis.sort((a, b) => { if (less) { - const result = order == "asc" ? -less(a[0], b[0]) : less(a[0], b[0]) + const ascending = order == "asc" + const result = ascending + ? -less(a[0], b[0], ascending) + : less(a[0], b[0], ascending) if (result !== 0) return result } else { const result = comparator(a[0], b[0]) diff --git a/optuna_dashboard/ts/components/GraphContour.tsx b/optuna_dashboard/ts/components/GraphContour.tsx index b186fffeb..a5416f5b8 100644 --- a/optuna_dashboard/ts/components/GraphContour.tsx +++ b/optuna_dashboard/ts/components/GraphContour.tsx @@ -195,12 +195,16 @@ const plotContour = ( const xValues: plotly.Datum[] = [] const yValues: plotly.Datum[] = [] const zValues: plotly.Datum[][] = new Array(yIndices.length) + const feasibleXY = new Set() for (let j = 0; j < yIndices.length; j++) { zValues[j] = new Array(xIndices.length).fill(null) } filteredTrials.forEach((trial, i) => { if (xAxis.values[i] && yAxis.values[i] && trial.values) { + if (trial.constraints.every((c) => c <= 0)) { + feasibleXY.add(xValues.length) + } const xValue = xAxis.values[i] as string | number const yValue = yAxis.values[i] as string | number xValues.push(xValue) @@ -234,12 +238,20 @@ const plotContour = ( }, { type: "scatter", - x: xValues, - y: yValues, + x: xValues.filter((_, i) => feasibleXY.has(i)), + y: yValues.filter((_, i) => feasibleXY.has(i)), marker: { line: { width: 2.0, color: "Grey" }, color: "black" }, mode: "markers", showlegend: false, }, + { + type: "scatter", + x: xValues.filter((_, i) => !feasibleXY.has(i)), + y: yValues.filter((_, i) => !feasibleXY.has(i)), + marker: { line: { width: 2.0, color: "Grey" }, color: "#cccccc" }, + mode: "markers", + showlegend: false, + }, ] plotly.react(plotDomId, plotData, layout) return diff --git a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx index ca0bfda0c..2d86e4649 100644 --- a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx +++ b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx @@ -84,16 +84,21 @@ const plotIntermediateValue = ( const values = trial.intermediate_values.filter( (iv) => iv.value !== "inf" && iv.value !== "-inf" && iv.value !== "nan" ) + const isFeasible = trial.constraints.every((c) => c <= 0) return { x: values.map((iv) => iv.step), y: values.map((iv) => iv.value), marker: { maxdisplayed: 10 }, mode: "lines+markers", type: "scatter", - name: - trial.state !== "Running" - ? `trial #${trial.number}` - : `trial #${trial.number} (running)`, + name: `trial #${trial.number} ${ + trial.state === "Running" + ? "(running)" + : !isFeasible + ? "(infeasible)" + : "" + }`, + ...(!isFeasible && { line: { color: "#CCCCCC" } }), } }) plotly.react(plotDomId, plotData, layout) diff --git a/optuna_dashboard/ts/components/GraphRank.tsx b/optuna_dashboard/ts/components/GraphRank.tsx new file mode 100644 index 000000000..6781793fe --- /dev/null +++ b/optuna_dashboard/ts/components/GraphRank.tsx @@ -0,0 +1,414 @@ +import * as plotly from "plotly.js-dist-min" +import React, { FC, useEffect, useState } from "react" +import { + Grid, + FormControl, + FormLabel, + MenuItem, + Select, + Typography, + SelectChangeEvent, + useTheme, + Box, +} from "@mui/material" +import { plotlyDarkTemplate } from "./PlotlyDarkMode" +import { makeHovertext } from "../graphUtil" +import { useMergedUnionSearchSpace } from "../searchSpace" + +const PADDING_RATIO = 0.05 +const plotDomId = "graph-rank" + +interface AxisInfo { + name: string + range: [number, number] + isLog: boolean + isCat: boolean +} + +interface RankPlotInfo { + xaxis: AxisInfo + yaxis: AxisInfo + xvalues: (string | number)[] + yvalues: (string | number)[] + zvalues: number[] + colors: number[] + is_feasible: boolean[] + hovertext: string[] +} + +export const GraphRank: FC<{ + study: StudyDetail | null +}> = ({ study = null }) => { + const theme = useTheme() + const [objectiveId, setobjectiveId] = useState(0) + const searchSpace = useMergedUnionSearchSpace(study?.union_search_space) + const [xParam, setXParam] = useState(null) + const [yParam, setYParam] = useState(null) + const objectiveNames: string[] = study?.objective_names || [] + + if (xParam === null && searchSpace.length > 0) { + setXParam(searchSpace[0]) + } + if (yParam === null && searchSpace.length > 1) { + setYParam(searchSpace[1]) + } + + const handleObjectiveChange = (event: SelectChangeEvent) => { + setobjectiveId(Number(event.target.value)) + } + const handleXParamChange = (event: SelectChangeEvent) => { + const param = searchSpace.find((item) => item.name === event.target.value) + setXParam(param || null) + } + const handleYParamChange = (event: SelectChangeEvent) => { + const param = searchSpace.find((item) => item.name === event.target.value) + setYParam(param || null) + } + + useEffect(() => { + if (study != null) { + const rankPlotInfo = getRankPlotInfo(study, objectiveId, xParam, yParam) + plotRank(rankPlotInfo, theme.palette.mode) + } + }, [study, objectiveId, xParam, yParam, theme.palette.mode]) + + const space: SearchSpaceItem[] = study ? study.union_search_space : [] + + return ( + + + + Rank + + {study !== null && study.directions.length !== 1 ? ( + + Objective: + + + ) : null} + {study !== null && space.length > 0 ? ( + + + x: + + + + y: + + + + ) : null} + + + + + + ) +} + +const getRankPlotInfo = ( + study: StudyDetail | null, + objectiveId: number, + xParam: SearchSpaceItem | null, + yParam: SearchSpaceItem | null +): RankPlotInfo | null => { + if (study === null) { + return null + } + + const trials = study.trials + const filteredTrials = trials.filter(filterFunc) + if (filteredTrials.length < 2 || xParam === null || yParam === null) { + return null + } + + const xAxis = getAxisInfo(filteredTrials, xParam) + const yAxis = getAxisInfo(filteredTrials, yParam) + + const xValues: (string | number)[] = [] + const yValues: (string | number)[] = [] + const zValues: number[] = [] + const isFeasible: boolean[] = [] + const hovertext: string[] = [] + filteredTrials.forEach((trial) => { + const xValue = + trial.params.find((p) => p.name === xAxis.name)?.param_external_value || + null + const yValue = + trial.params.find((p) => p.name === yAxis.name)?.param_external_value || + null + if (trial.values === undefined || xValue === null || yValue === null) { + return + } + const zValue = Number(trial.values[objectiveId]) + const feasibility = trial.constraints.every((c) => c <= 0) + xValues.push(xValue) + yValues.push(yValue) + zValues.push(zValue) + isFeasible.push(feasibility) + hovertext.push(makeHovertext(trial)) + }) + + const colors = getColors(zValues) + + return { + xaxis: xAxis, + yaxis: yAxis, + xvalues: xValues, + yvalues: yValues, + zvalues: zValues, + colors, + is_feasible: isFeasible, + hovertext, + } +} + +const filterFunc = (trial: Trial): boolean => { + return trial.state === "Complete" && trial.values !== undefined +} + +const getAxisInfo = (trials: Trial[], param: SearchSpaceItem): AxisInfo => { + if (param.distribution.type === "CategoricalDistribution") { + return getAxisInfoForCategorical(trials, param.name, param.distribution) + } else { + return getAxisInfoForNumerical(trials, param.name, param.distribution) + } +} + +const getAxisInfoForCategorical = ( + trials: Trial[], + param: string, + distribution: CategoricalDistribution +): AxisInfo => { + const values = trials.map( + (trial) => + trial.params.find((p) => p.name === param)?.param_internal_value || null + ) + const isDynamic = values.some((v) => v === null) + const span = distribution.choices.length - (isDynamic ? 2 : 1) + const padding = span * PADDING_RATIO + const min = -padding + const max = span + padding + + return { + name: param, + range: [min, max], + isLog: false, + isCat: true, + } +} + +const getAxisInfoForNumerical = ( + trials: Trial[], + param: string, + distribution: FloatDistribution | IntDistribution +): AxisInfo => { + const values = trials.map( + (trial) => + trial.params.find((p) => p.name === param)?.param_internal_value || null + ) + const nonNullValues: number[] = [] + values.forEach((value) => { + if (value !== null) { + nonNullValues.push(value) + } + }) + let min = Math.min(...nonNullValues) + let max = Math.max(...nonNullValues) + if (distribution.log) { + const padding = (Math.log10(max) - Math.log10(min)) * PADDING_RATIO + min = Math.pow(10, Math.log10(min) - padding) + max = Math.pow(10, Math.log10(max) + padding) + } else { + const padding = (max - min) * PADDING_RATIO + min = min - padding + max = max + padding + } + + return { + name: param, + range: [min, max], + isLog: distribution.log, + isCat: false, + } +} + +const getColors = (values: number[]): number[] => { + const rawRanks = getOrderWithSameOrderAveraging(values) + let colorIdxs: number[] = [] + if (values.length > 2) { + colorIdxs = rawRanks.map((rank) => rank / (values.length - 1)) + } else { + colorIdxs = [0.5] + } + return colorIdxs +} + +const getOrderWithSameOrderAveraging = (values: number[]): number[] => { + const sortedValues = values.slice().sort() + const ranks: number[] = [] + values.forEach((value) => { + const firstIndex = sortedValues.indexOf(value) + const lastIndex = sortedValues.lastIndexOf(value) + const sumOfTheValue = + ((firstIndex + lastIndex) * (lastIndex - firstIndex + 1)) / 2 + const rank = sumOfTheValue / (lastIndex - firstIndex + 1) + ranks.push(rank) + }) + return ranks +} + +const plotRank = (rankPlotInfo: RankPlotInfo | null, mode: string) => { + if (document.getElementById(plotDomId) === null) { + return + } + + if (rankPlotInfo === null) { + plotly.react(plotDomId, [], { + template: mode === "dark" ? plotlyDarkTemplate : {}, + }) + return + } + + const xAxis = rankPlotInfo.xaxis + const yAxis = rankPlotInfo.yaxis + const layout: Partial = { + xaxis: { + title: xAxis.name, + type: xAxis.isCat ? "category" : xAxis.isLog ? "log" : "linear", + }, + yaxis: { + title: yAxis.name, + type: yAxis.isCat ? "category" : yAxis.isLog ? "log" : "linear", + }, + margin: { + l: 50, + t: 0, + r: 50, + b: 50, + }, + uirevision: "true", + template: mode === "dark" ? plotlyDarkTemplate : {}, + } + + let xValues = rankPlotInfo.xvalues + let yValues = rankPlotInfo.yvalues + if (xAxis.isCat && !yAxis.isCat) { + const xIndices: number[] = Array.from(Array(xValues.length).keys()).sort( + (a, b) => + xValues[a] + .toString() + .toLowerCase() + .localeCompare(xValues[b].toString().toLowerCase()) + ) + xValues = xIndices.map((i) => xValues[i]) + yValues = xIndices.map((i) => yValues[i]) + } + if (!xAxis.isCat && yAxis.isCat) { + const yIndices: number[] = Array.from(Array(yValues.length).keys()).sort( + (a, b) => + yValues[a] + .toString() + .toLowerCase() + .localeCompare(yValues[b].toString().toLowerCase()) + ) + xValues = yIndices.map((i) => xValues[i]) + yValues = yIndices.map((i) => yValues[i]) + } + if (xAxis.isCat && yAxis.isCat) { + const indices: number[] = Array.from(Array(xValues.length).keys()).sort( + (a, b) => { + const xComp = xValues[a] + .toString() + .toLowerCase() + .localeCompare(xValues[b].toString().toLowerCase()) + if (xComp !== 0) { + return xComp + } + return yValues[a] + .toString() + .toLowerCase() + .localeCompare(yValues[b].toString().toLowerCase()) + } + ) + xValues = indices.map((i) => xValues[i]) + yValues = indices.map((i) => yValues[i]) + } + + const plotData: Partial[] = [ + { + type: "scatter", + x: xValues.filter((_, i) => rankPlotInfo.is_feasible[i]), + y: yValues.filter((_, i) => rankPlotInfo.is_feasible[i]), + marker: { + color: rankPlotInfo.colors.filter( + (_, i) => rankPlotInfo.is_feasible[i] + ), + colorscale: "Portland", + colorbar: { + title: "Rank", + }, + size: 10, + line: { + color: "Grey", + width: 0.5, + }, + }, + mode: "markers", + showlegend: false, + hovertemplate: "%{hovertext}", + hovertext: rankPlotInfo.hovertext.filter( + (_, i) => rankPlotInfo.is_feasible[i] + ), + }, + { + type: "scatter", + x: xValues.filter((_, i) => !rankPlotInfo.is_feasible[i]), + y: yValues.filter((_, i) => !rankPlotInfo.is_feasible[i]), + marker: { + color: "#cccccc", + size: 10, + line: { + color: "Grey", + width: 0.5, + }, + }, + mode: "markers", + showlegend: false, + hovertemplate: "%{hovertext}", + hovertext: rankPlotInfo.hovertext.filter( + (_, i) => !rankPlotInfo.is_feasible[i] + ), + }, + ] + plotly.react(plotDomId, plotData, layout) +} diff --git a/optuna_dashboard/ts/components/GraphSlice.tsx b/optuna_dashboard/ts/components/GraphSlice.tsx index f37a54136..2a807cb71 100644 --- a/optuna_dashboard/ts/components/GraphSlice.tsx +++ b/optuna_dashboard/ts/components/GraphSlice.tsx @@ -188,65 +188,71 @@ const plotSlice = ( return } - const objectiveValues: number[] = trials.map( + const feasibleTrials: Trial[] = [] + const infeasibleTrials: Trial[] = [] + trials.forEach((t) => { + if (t.constraints.every((c) => c <= 0)) { + feasibleTrials.push(t) + } else { + infeasibleTrials.push(t) + } + }) + + const feasibleObjectiveValues: number[] = feasibleTrials.map( (t) => objectiveTarget.getTargetValue(t) as number ) - const values = trials.map( - (t) => selectedParamTarget.getTargetValue(t) as number + const infeasibleObjectiveValues: number[] = infeasibleTrials.map( + (t) => objectiveTarget.getTargetValue(t) as number ) - const trialNumbers: number[] = trials.map((t) => t.number) - if (selectedParamSpace.distribution.type !== "CategoricalDistribution") { - const trace: plotly.Data[] = [ - { - type: "scatter", - x: values, - y: objectiveValues, - mode: "markers", - marker: { - color: trialNumbers, - colorscale: "Blues", - reversescale: true, - colorbar: { - title: "Trial", - }, - line: { - color: "Grey", - width: 0.5, - }, + const feasibleValues = feasibleTrials.map( + (t) => selectedParamTarget.getTargetValue(t) as number + ) + const infeasibleValues = infeasibleTrials.map( + (t) => selectedParamTarget.getTargetValue(t) as number + ) + const trace: plotly.Data[] = [ + { + type: "scatter", + x: feasibleValues, + y: feasibleObjectiveValues, + mode: "markers", + name: "Feasible Trial", + marker: { + color: feasibleTrials.map((t) => t.number), + colorscale: "Blues", + reversescale: true, + colorbar: { + title: "Trial", + }, + line: { + color: "Grey", + width: 0.5, }, }, - ] + }, + { + type: "scatter", + x: infeasibleValues, + y: infeasibleObjectiveValues, + mode: "markers", + name: "Infeasible Trial", + marker: { + color: "#cccccc", + reversescale: true, + }, + }, + ] + if (selectedParamSpace.distribution.type !== "CategoricalDistribution") { layout["xaxis"] = { title: selectedParamTarget.toLabel(), type: isLogScale(selectedParamSpace) ? "log" : "linear", gridwidth: 1, automargin: true, // Otherwise the label is outside of the plot } - plotly.react(plotDomId, trace, layout) } else { const vocabArr = selectedParamSpace.distribution.choices.map((c) => c.value) const tickvals: number[] = vocabArr.map((v, i) => i) - const trace: plotly.Data[] = [ - { - type: "scatter", - x: values, - y: objectiveValues, - mode: "markers", - marker: { - color: trialNumbers, - colorscale: "Blues", - reversescale: true, - colorbar: { - title: "Trial", - }, - line: { - color: "Grey", - width: 0.5, - }, - }, - }, - ] layout["xaxis"] = { title: selectedParamTarget.toLabel(), type: "linear", @@ -255,6 +261,6 @@ const plotSlice = ( ticktext: vocabArr, automargin: true, // Otherwise the label is outside of the plot } - plotly.react(plotDomId, trace, layout) } + plotly.react(plotDomId, trace, layout) } diff --git a/optuna_dashboard/ts/components/StudyDetail.tsx b/optuna_dashboard/ts/components/StudyDetail.tsx index 1a8ced334..663ff3993 100644 --- a/optuna_dashboard/ts/components/StudyDetail.tsx +++ b/optuna_dashboard/ts/components/StudyDetail.tsx @@ -27,6 +27,7 @@ import { GraphParallelCoordinate } from "./GraphParallelCoordinate" import { Contour } from "./GraphContour" import { GraphSlice } from "./GraphSlice" import { GraphEdf } from "./GraphEdf" +import { GraphRank } from "./GraphRank" import { TrialList } from "./TrialList" import { StudyHistory } from "./StudyHistory" import { PreferentialTrials } from "./PreferentialTrials" @@ -121,6 +122,11 @@ export const StudyDetail: FC<{ + + + + + Empirical Distribution of the Objective Value diff --git a/optuna_dashboard/ts/components/TrialTable.tsx b/optuna_dashboard/ts/components/TrialTable.tsx index 8fcce63c8..c708f6ee0 100644 --- a/optuna_dashboard/ts/components/TrialTable.tsx +++ b/optuna_dashboard/ts/components/TrialTable.tsx @@ -28,7 +28,7 @@ export const TrialTable: FC<{ field: "values", label: "Value", sortable: true, - less: (firstEl, secondEl): number => { + less: (firstEl, secondEl, ascending): number => { const firstVal = firstEl.values?.[0] const secondVal = secondEl.values?.[0] @@ -36,9 +36,9 @@ export const TrialTable: FC<{ return 0 } if (firstVal === undefined) { - return -1 + return ascending ? -1 : 1 } else if (secondVal === undefined) { - return 1 + return ascending ? 1 : -1 } if (firstVal === "-inf" || secondVal === "inf") { return 1 @@ -63,7 +63,7 @@ export const TrialTable: FC<{ ? objectiveNames[objectiveId] : `Objective ${objectiveId}`, sortable: true, - less: (firstEl, secondEl): number => { + less: (firstEl, secondEl, ascending): number => { const firstVal = firstEl.values?.[objectiveId] const secondVal = secondEl.values?.[objectiveId] @@ -71,9 +71,9 @@ export const TrialTable: FC<{ return 0 } if (firstVal === undefined) { - return -1 + return ascending ? -1 : 1 } else if (secondVal === undefined) { - return 1 + return ascending ? 1 : -1 } if (firstVal === "-inf" || secondVal === "inf") { return 1 @@ -106,7 +106,8 @@ export const TrialTable: FC<{ ?.param_external_value || null, sortable: sortable, filterable: filterable, - less: (firstEl, secondEl): number => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + less: (firstEl, secondEl, _): number => { const firstVal = firstEl.params.find( (p) => p.name === s.name )?.param_internal_value @@ -146,7 +147,8 @@ export const TrialTable: FC<{ ?.value || null, sortable: attr_spec.sortable, filterable: !attr_spec.sortable, - less: (firstEl, secondEl): number => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + less: (firstEl, secondEl, _): number => { const firstVal = firstEl.user_attrs.find( (attr) => attr.key === attr_spec.key )?.value diff --git a/python_tests/preferential/samplers/test_gp.py b/python_tests/preferential/samplers/test_gp.py new file mode 100644 index 000000000..136959a79 --- /dev/null +++ b/python_tests/preferential/samplers/test_gp.py @@ -0,0 +1,29 @@ +import sys +from unittest.mock import patch + +import numpy as np +import pytest + + +if sys.version_info >= (3, 8): + from optuna_dashboard.preferential.samplers.gp import _one_side_trunc_norm_sampling + import torch +else: + pytest.skip("BoTorch dropped Python3.7 support", allow_module_level=True) + + +def test_one_side_trunc_norm_sampling() -> None: + for lower in np.linspace(-10, 10, 100): + assert _one_side_trunc_norm_sampling(torch.tensor([lower], dtype=torch.float64)) >= lower + + with patch.object(torch, "rand", return_value=torch.tensor([0.4], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([0.1], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), 0.899967154837563) + + with patch.object(torch, "rand", return_value=torch.tensor([0.8], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([-2.3], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), -0.8113606739551955) + + with patch.object(torch, "rand", return_value=torch.tensor([0.1], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([5], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), 5.426934003050024) diff --git a/standalone_app/src/components/DataGrid.tsx b/standalone_app/src/components/DataGrid.tsx index 2d6a5670c..24e45dc78 100644 --- a/standalone_app/src/components/DataGrid.tsx +++ b/standalone_app/src/components/DataGrid.tsx @@ -28,7 +28,7 @@ interface DataGridColumn { field: keyof T label: string sortable?: boolean - less?: (a: T, b: T) => number + less?: (a: T, b: T, ascending: boolean) => number filterable?: boolean toCellValue?: (rowIndex: number) => string | React.ReactNode padding?: "normal" | "checkbox" | "none" @@ -358,7 +358,10 @@ function stableSort( const stabilizedThis = array.map((el, index) => [el, index] as [T, number]) stabilizedThis.sort((a, b) => { if (less) { - const result = order == "asc" ? -less(a[0], b[0]) : less(a[0], b[0]) + const ascending = order == "asc" + const result = ascending + ? -less(a[0], b[0], ascending) + : less(a[0], b[0], ascending) if (result !== 0) return result } else { const result = comparator(a[0], b[0]) diff --git a/standalone_app/src/components/TrialTable.tsx b/standalone_app/src/components/TrialTable.tsx index 58d2d101e..97a04f3c0 100644 --- a/standalone_app/src/components/TrialTable.tsx +++ b/standalone_app/src/components/TrialTable.tsx @@ -25,7 +25,7 @@ export const TrialTable: FC<{ field: "values", label: "Value", sortable: true, - less: (firstEl, secondEl): number => { + less: (firstEl, secondEl, ascending): number => { const firstVal = firstEl.values?.[0] const secondVal = secondEl.values?.[0] @@ -33,9 +33,9 @@ export const TrialTable: FC<{ return 0 } if (firstVal === undefined) { - return -1 + return ascending ? -1 : 1 } else if (secondVal === undefined) { - return 1 + return ascending ? 1 : -1 } if (firstVal === "-inf" || secondVal === "inf") { return 1 @@ -57,7 +57,7 @@ export const TrialTable: FC<{ field: "values", label: `Objective ${objectiveId}`, sortable: true, - less: (firstEl, secondEl): number => { + less: (firstEl, secondEl, ascending): number => { const firstVal = firstEl.values?.[objectiveId] const secondVal = secondEl.values?.[objectiveId] @@ -65,9 +65,9 @@ export const TrialTable: FC<{ return 0 } if (firstVal === undefined) { - return -1 + return ascending ? -1 : 1 } else if (secondVal === undefined) { - return 1 + return ascending ? 1 : -1 } if (firstVal === "-inf" || secondVal === "inf") { return 1 @@ -96,7 +96,8 @@ export const TrialTable: FC<{ null, sortable: true, filterable: false, - less: (firstEl, secondEl): number => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + less: (firstEl, secondEl, _): number => { const firstVal = firstEl.params.find( (p) => p.name === s.name )?.param_internal_value @@ -126,7 +127,8 @@ export const TrialTable: FC<{ ?.value || null, sortable: attr_spec.sortable, filterable: false, - less: (firstEl, secondEl): number => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + less: (firstEl, secondEl, _): number => { const firstVal = firstEl.user_attrs.find( (attr) => attr.key === attr_spec.key )?.value