From 5a356a3823585f2e9fe91d1e7c84e32fc0930713 Mon Sep 17 00:00:00 2001 From: Christoph Hornung Date: Tue, 6 Sep 2022 19:15:33 +0200 Subject: [PATCH 1/2] NelderMead reimplementation --- .../Optimization/NelderMeadSimplex.cs | 157 +++++++++++++----- 1 file changed, 117 insertions(+), 40 deletions(-) diff --git a/src/Numerics/Optimization/NelderMeadSimplex.cs b/src/Numerics/Optimization/NelderMeadSimplex.cs index b903e4921..db52bfff3 100644 --- a/src/Numerics/Optimization/NelderMeadSimplex.cs +++ b/src/Numerics/Optimization/NelderMeadSimplex.cs @@ -30,6 +30,7 @@ // Converted from code released with a MIT license available at https://code.google.com/p/nelder-mead-simplex/ using System; +using System.Reflection; using MathNet.Numerics.LinearAlgebra; namespace MathNet.Numerics.Optimization @@ -151,37 +152,100 @@ public static MinimizationResult Minimum(IObjectiveFunction objectiveFunction, V break; } - // attempt a reflection of the simplex - double reflectionPointValue = TryToScaleSimplex(-1.0, ref errorProfile, vertices, errorValues, objectiveFunction); + // This algorithm follows https://www.scilab.org/sites/default/files/neldermead.pdf we give the + // lines from Figure 4.1. to better follow along. Note that the values we use for + // ρ (rho) = 1, χ (chi) =2, γ (gamma) = 0.5 and σ (sigma) = 0.5 are the default values given in the paper and + // match the values used here https://se.mathworks.com/help/matlab/math/optimizing-nonlinear-functions.html#bsgpq6p-11 + + // calculate the centroid + // x ← x(n + 1) + Vector centroid = ComputeCentroid(vertices, errorProfile); + + // attempt a reflection of the simplex - using our default for rho + // x_r ← x(ρ, n + 1) {Reflect} + // f_r ← f(x_r) + (Vector reflectionPoint, double reflectionPointValue) = ScaleSimplex(1.0, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; - if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) + + // if f_r < f_1 then + if (reflectionPointValue < errorValues[errorProfile.LowestIndex]) { - // it's better than the best point, so attempt an expansion of the simplex - TryToScaleSimplex(2.0, ref errorProfile, vertices, errorValues, objectiveFunction); + // it's better than the best point, but we attempt to improve even that by expanding the simplex + // x_e ← x(ρχ, n + 1) {Expand} + // f_e ← f(x_e) + (Vector expansionPoint, double expansionPointValue) = ScaleSimplex(2.0, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; + + // if f_e < f_r then + if (expansionPointValue < reflectionPointValue) + { + // Accept x_e + AcceptNewVertex(expansionPoint, expansionPointValue, ref errorProfile, vertices, errorValues); + } + else + { + // Accept x_r + AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues); + } + } + // else if f_1 ≤ f_r < f_n then + else if (reflectionPointValue < errorValues[errorProfile.NextHighestIndex]) + { + // Accept x_r + AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues); } - else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) + // else if f_n ≤ f_r < f_n+1 then + else if (reflectionPointValue < errorValues[errorProfile.HighestIndex]) + { + // x_c ← x(ργ, n + 1) {Outside contraction} + // f_c ← f(x_c) + (Vector contractionPoint, double contractionPointValue) = ScaleSimplex(0.5, ref errorProfile, centroid, vertices, objectiveFunction); + // if f_c < f_r then + if (contractionPointValue < reflectionPointValue) + { + // Accept x_c + AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues); + } + // else + else + { + // Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute f_i = f(v_i) for i = 2, n + 1 + ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction); + evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track + } + } + // else + else { - // it would be worse than the second best point, so attempt a contraction to look - // for an intermediate point - double currentWorst = errorValues[errorProfile.HighestIndex]; - double contractionPointValue = TryToScaleSimplex(0.5, ref errorProfile, vertices, errorValues, objectiveFunction); + // The reflected value is worse than even the worst vertex of the current simplex + // x_c ← x(−γ, n + 1) {Inside contraction} + // f_c ← f(x_c) + (Vector contractionPoint, double contractionPointValue) = ScaleSimplex(-0.5, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; - if (contractionPointValue >= currentWorst) + + // if fc < fn+1 then + if (contractionPointValue < errorValues[errorProfile.HighestIndex]) + { + // Accept x_c + AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues); + } + // else + else { - // that would be even worse, so let's try to contract uniformly towards the low point; - // don't bother to update the error profile, we'll do it at the start of the - // next iteration + // Compute the points xi = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute fi = f(vi) for i = 2, n + 1 ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction); evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track } } - // check to see if we have exceeded our alloted number of evaluations + // check to see if we have exceeded our allotted number of evaluations if (evaluationCount >= maximumIterations) { throw new MaximumIterationsException(FormattableString.Invariant($"Maximum iterations ({maximumIterations}) reached.")); } } + objectiveFunction.EvaluateAt(vertices[errorProfile.LowestIndex]); var regressionResult = new MinimizationResult(objectiveFunction, evaluationCount, exitCondition); return regressionResult; @@ -293,38 +357,42 @@ static Vector[] InitializeVertices(SimplexConstant[] simplexConstants) } /// - /// Test a scaling operation of the high point, and replace it if it is an improvement + /// Calculates a new simplex by moving the worst point along the line given by itself and the centroid. /// - /// - /// - /// - /// - /// - /// - static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile, Vector[] vertices, - double[] errorValues, IObjectiveFunction objectiveFunction) + /// This is called the x-function in the paper https://www.scilab.org/sites/default/files/neldermead.pdf (4.4) + /// The factor to scale along the given line. + /// The error profile. + /// The centroid of the simplex. + /// The simplex. + /// The objective function. + /// The point that would replace the worst thus defining the scaled simplex. + static (Vector scaledPoint, double scaledValue) ScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile, + Vector centroid, Vector[] vertices, IObjectiveFunction objectiveFunction) { - // find the centroid through which we will reflect - Vector centroid = ComputeCentroid(vertices, errorProfile); - - // define the vector from the centroid to the high point - Vector centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid); + // define the vector from the high point to the centroid + Vector highPointToCentroid = centroid.Subtract(vertices[errorProfile.HighestIndex]); // scale and position the vector to determine the new trial point - Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid); + Vector newPoint = highPointToCentroid.Multiply(scaleFactor).Add(centroid); // evaluate the new point objectiveFunction.EvaluateAt(newPoint); - double newErrorValue = objectiveFunction.Value; - - // if it's better, replace the old high point - if (newErrorValue < errorValues[errorProfile.HighestIndex]) - { - vertices[errorProfile.HighestIndex] = newPoint; - errorValues[errorProfile.HighestIndex] = newErrorValue; - } + return (newPoint, objectiveFunction.Value); + } - return newErrorValue; + /// + /// Accept the new point as the new vertex of the simplex, replacing the worst point. + /// + /// The new point. + /// The error value at that point. + /// The error profile. + /// The vertices of the simplex. + /// The error values of the simplex. + static void AcceptNewVertex(Vector newPoint, double newErrorValue, ref ErrorProfile errorProfile, Vector[] vertices, + double[] errorValues) + { + vertices[errorProfile.HighestIndex] = newPoint; + errorValues[errorProfile.HighestIndex] = newErrorValue; } /// @@ -337,12 +405,21 @@ static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfil static void ShrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues, IObjectiveFunction objectiveFunction) { + // Let's try to contract uniformly towards the low point; + // don't bother to update the error profile, we'll do it at the start of the + // next iteration + // In the paper this is written as: + // Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute f_i = f(v_i) for i = 2, n + 1 + Vector lowestVertex = vertices[errorProfile.LowestIndex]; for (int i = 0; i < vertices.Length; i++) { if (i != errorProfile.LowestIndex) { - vertices[i] = (vertices[i].Add(lowestVertex)).Multiply(0.5); + // x_i = x_1 + σ(x_i − x_1) with σ = 1/2 is equal to + // x_i = (x_1 + x_i) / 2 + vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5); objectiveFunction.EvaluateAt(vertices[i]); errorValues[i] = objectiveFunction.Value; } From 51c536b1a6aa69bfbedafa8f8a920fc4a8347463 Mon Sep 17 00:00:00 2001 From: Christoph Hornung Date: Tue, 6 Sep 2022 19:19:12 +0200 Subject: [PATCH 2/2] Fixed using --- src/Numerics/Optimization/NelderMeadSimplex.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Numerics/Optimization/NelderMeadSimplex.cs b/src/Numerics/Optimization/NelderMeadSimplex.cs index db52bfff3..542ca90a2 100644 --- a/src/Numerics/Optimization/NelderMeadSimplex.cs +++ b/src/Numerics/Optimization/NelderMeadSimplex.cs @@ -30,7 +30,6 @@ // Converted from code released with a MIT license available at https://code.google.com/p/nelder-mead-simplex/ using System; -using System.Reflection; using MathNet.Numerics.LinearAlgebra; namespace MathNet.Numerics.Optimization