diff --git a/src/Numerics/Optimization/NelderMeadSimplex.cs b/src/Numerics/Optimization/NelderMeadSimplex.cs index b903e4921..542ca90a2 100644 --- a/src/Numerics/Optimization/NelderMeadSimplex.cs +++ b/src/Numerics/Optimization/NelderMeadSimplex.cs @@ -151,37 +151,100 @@ public static MinimizationResult Minimum(IObjectiveFunction objectiveFunction, V break; } - // attempt a reflection of the simplex - double reflectionPointValue = TryToScaleSimplex(-1.0, ref errorProfile, vertices, errorValues, objectiveFunction); + // This algorithm follows https://www.scilab.org/sites/default/files/neldermead.pdf we give the + // lines from Figure 4.1. to better follow along. Note that the values we use for + // ρ (rho) = 1, χ (chi) =2, γ (gamma) = 0.5 and σ (sigma) = 0.5 are the default values given in the paper and + // match the values used here https://se.mathworks.com/help/matlab/math/optimizing-nonlinear-functions.html#bsgpq6p-11 + + // calculate the centroid + // x ← x(n + 1) + Vector centroid = ComputeCentroid(vertices, errorProfile); + + // attempt a reflection of the simplex - using our default for rho + // x_r ← x(ρ, n + 1) {Reflect} + // f_r ← f(x_r) + (Vector reflectionPoint, double reflectionPointValue) = ScaleSimplex(1.0, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; - if (reflectionPointValue <= errorValues[errorProfile.LowestIndex]) + + // if f_r < f_1 then + if (reflectionPointValue < errorValues[errorProfile.LowestIndex]) { - // it's better than the best point, so attempt an expansion of the simplex - TryToScaleSimplex(2.0, ref errorProfile, vertices, errorValues, objectiveFunction); + // it's better than the best point, but we attempt to improve even that by expanding the simplex + // x_e ← x(ρχ, n + 1) {Expand} + // f_e ← f(x_e) + (Vector expansionPoint, double expansionPointValue) = ScaleSimplex(2.0, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; + + // if f_e < f_r then + if (expansionPointValue < reflectionPointValue) + { + // Accept x_e + AcceptNewVertex(expansionPoint, expansionPointValue, ref errorProfile, vertices, errorValues); + } + else + { + // Accept x_r + AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues); + } + } + // else if f_1 ≤ f_r < f_n then + else if (reflectionPointValue < errorValues[errorProfile.NextHighestIndex]) + { + // Accept x_r + AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues); + } + // else if f_n ≤ f_r < f_n+1 then + else if (reflectionPointValue < errorValues[errorProfile.HighestIndex]) + { + // x_c ← x(ργ, n + 1) {Outside contraction} + // f_c ← f(x_c) + (Vector contractionPoint, double contractionPointValue) = ScaleSimplex(0.5, ref errorProfile, centroid, vertices, objectiveFunction); + // if f_c < f_r then + if (contractionPointValue < reflectionPointValue) + { + // Accept x_c + AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues); + } + // else + else + { + // Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute f_i = f(v_i) for i = 2, n + 1 + ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction); + evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track + } } - else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex]) + // else + else { - // it would be worse than the second best point, so attempt a contraction to look - // for an intermediate point - double currentWorst = errorValues[errorProfile.HighestIndex]; - double contractionPointValue = TryToScaleSimplex(0.5, ref errorProfile, vertices, errorValues, objectiveFunction); + // The reflected value is worse than even the worst vertex of the current simplex + // x_c ← x(−γ, n + 1) {Inside contraction} + // f_c ← f(x_c) + (Vector contractionPoint, double contractionPointValue) = ScaleSimplex(-0.5, ref errorProfile, centroid, vertices, objectiveFunction); ++evaluationCount; - if (contractionPointValue >= currentWorst) + + // if fc < fn+1 then + if (contractionPointValue < errorValues[errorProfile.HighestIndex]) { - // that would be even worse, so let's try to contract uniformly towards the low point; - // don't bother to update the error profile, we'll do it at the start of the - // next iteration + // Accept x_c + AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues); + } + // else + else + { + // Compute the points xi = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute fi = f(vi) for i = 2, n + 1 ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction); evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track } } - // check to see if we have exceeded our alloted number of evaluations + // check to see if we have exceeded our allotted number of evaluations if (evaluationCount >= maximumIterations) { throw new MaximumIterationsException(FormattableString.Invariant($"Maximum iterations ({maximumIterations}) reached.")); } } + objectiveFunction.EvaluateAt(vertices[errorProfile.LowestIndex]); var regressionResult = new MinimizationResult(objectiveFunction, evaluationCount, exitCondition); return regressionResult; @@ -293,38 +356,42 @@ static Vector[] InitializeVertices(SimplexConstant[] simplexConstants) } /// - /// Test a scaling operation of the high point, and replace it if it is an improvement + /// Calculates a new simplex by moving the worst point along the line given by itself and the centroid. /// - /// - /// - /// - /// - /// - /// - static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile, Vector[] vertices, - double[] errorValues, IObjectiveFunction objectiveFunction) + /// This is called the x-function in the paper https://www.scilab.org/sites/default/files/neldermead.pdf (4.4) + /// The factor to scale along the given line. + /// The error profile. + /// The centroid of the simplex. + /// The simplex. + /// The objective function. + /// The point that would replace the worst thus defining the scaled simplex. + static (Vector scaledPoint, double scaledValue) ScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile, + Vector centroid, Vector[] vertices, IObjectiveFunction objectiveFunction) { - // find the centroid through which we will reflect - Vector centroid = ComputeCentroid(vertices, errorProfile); - - // define the vector from the centroid to the high point - Vector centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid); + // define the vector from the high point to the centroid + Vector highPointToCentroid = centroid.Subtract(vertices[errorProfile.HighestIndex]); // scale and position the vector to determine the new trial point - Vector newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid); + Vector newPoint = highPointToCentroid.Multiply(scaleFactor).Add(centroid); // evaluate the new point objectiveFunction.EvaluateAt(newPoint); - double newErrorValue = objectiveFunction.Value; - - // if it's better, replace the old high point - if (newErrorValue < errorValues[errorProfile.HighestIndex]) - { - vertices[errorProfile.HighestIndex] = newPoint; - errorValues[errorProfile.HighestIndex] = newErrorValue; - } + return (newPoint, objectiveFunction.Value); + } - return newErrorValue; + /// + /// Accept the new point as the new vertex of the simplex, replacing the worst point. + /// + /// The new point. + /// The error value at that point. + /// The error profile. + /// The vertices of the simplex. + /// The error values of the simplex. + static void AcceptNewVertex(Vector newPoint, double newErrorValue, ref ErrorProfile errorProfile, Vector[] vertices, + double[] errorValues) + { + vertices[errorProfile.HighestIndex] = newPoint; + errorValues[errorProfile.HighestIndex] = newErrorValue; } /// @@ -337,12 +404,21 @@ static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfil static void ShrinkSimplex(ErrorProfile errorProfile, Vector[] vertices, double[] errorValues, IObjectiveFunction objectiveFunction) { + // Let's try to contract uniformly towards the low point; + // don't bother to update the error profile, we'll do it at the start of the + // next iteration + // In the paper this is written as: + // Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink} + // Compute f_i = f(v_i) for i = 2, n + 1 + Vector lowestVertex = vertices[errorProfile.LowestIndex]; for (int i = 0; i < vertices.Length; i++) { if (i != errorProfile.LowestIndex) { - vertices[i] = (vertices[i].Add(lowestVertex)).Multiply(0.5); + // x_i = x_1 + σ(x_i − x_1) with σ = 1/2 is equal to + // x_i = (x_1 + x_i) / 2 + vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5); objectiveFunction.EvaluateAt(vertices[i]); errorValues[i] = objectiveFunction.Value; }