Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NelderMead reimplementation #951

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 116 additions & 40 deletions src/Numerics/Optimization/NelderMeadSimplex.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,37 +151,100 @@ public static MinimizationResult Minimum(IObjectiveFunction objectiveFunction, V
break;
}

// attempt a reflection of the simplex
double reflectionPointValue = TryToScaleSimplex(-1.0, ref errorProfile, vertices, errorValues, objectiveFunction);
// This algorithm follows https://www.scilab.org/sites/default/files/neldermead.pdf (we quote the
// lines from Figure 4.1 to make it easier to follow along). Note that the values we use,
// ρ (rho) = 1, χ (chi) = 2, γ (gamma) = 0.5 and σ (sigma) = 0.5, are the default values given in the paper and
// match the values used here: https://se.mathworks.com/help/matlab/math/optimizing-nonlinear-functions.html#bsgpq6p-11

// calculate the centroid
// x ← x(n + 1)
Vector<double> centroid = ComputeCentroid(vertices, errorProfile);

// attempt a reflection of the simplex - using our default for rho
// x_r ← x(ρ, n + 1) {Reflect}
// f_r ← f(x_r)
(Vector<double> reflectionPoint, double reflectionPointValue) = ScaleSimplex(1.0, ref errorProfile, centroid, vertices, objectiveFunction);
++evaluationCount;
if (reflectionPointValue <= errorValues[errorProfile.LowestIndex])

// if f_r < f_1 then
if (reflectionPointValue < errorValues[errorProfile.LowestIndex])
{
// it's better than the best point, so attempt an expansion of the simplex
TryToScaleSimplex(2.0, ref errorProfile, vertices, errorValues, objectiveFunction);
// it's better than the best point, but we attempt to improve even that by expanding the simplex
// x_e ← x(ρχ, n + 1) {Expand}
// f_e ← f(x_e)
(Vector<double> expansionPoint, double expansionPointValue) = ScaleSimplex(2.0, ref errorProfile, centroid, vertices, objectiveFunction);
++evaluationCount;

// if f_e < f_r then
if (expansionPointValue < reflectionPointValue)
{
// Accept x_e
AcceptNewVertex(expansionPoint, expansionPointValue, ref errorProfile, vertices, errorValues);
}
else
{
// Accept x_r
AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues);
}
}
// else if f_1 ≤ f_r < f_n then
else if (reflectionPointValue < errorValues[errorProfile.NextHighestIndex])
{
// Accept x_r
AcceptNewVertex(reflectionPoint, reflectionPointValue, ref errorProfile, vertices, errorValues);
}
// else if f_n ≤ f_r < f_n+1 then
else if (reflectionPointValue < errorValues[errorProfile.HighestIndex])
{
// x_c ← x(ργ, n + 1) {Outside contraction}
// f_c ← f(x_c)
(Vector<double> contractionPoint, double contractionPointValue) = ScaleSimplex(0.5, ref errorProfile, centroid, vertices, objectiveFunction);
// if f_c < f_r then
if (contractionPointValue < reflectionPointValue)
{
// Accept x_c
AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues);
}
// else
else
{
// Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink}
// Compute f_i = f(v_i) for i = 2, n + 1
ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track
}
}
else if (reflectionPointValue >= errorValues[errorProfile.NextHighestIndex])
// else
else
{
// it would be worse than the second best point, so attempt a contraction to look
// for an intermediate point
double currentWorst = errorValues[errorProfile.HighestIndex];
double contractionPointValue = TryToScaleSimplex(0.5, ref errorProfile, vertices, errorValues, objectiveFunction);
// The reflected value is worse than even the worst vertex of the current simplex
// x_c ← x(−γ, n + 1) {Inside contraction}
// f_c ← f(x_c)
(Vector<double> contractionPoint, double contractionPointValue) = ScaleSimplex(-0.5, ref errorProfile, centroid, vertices, objectiveFunction);
++evaluationCount;
if (contractionPointValue >= currentWorst)

// if f_c < f_n+1 then
if (contractionPointValue < errorValues[errorProfile.HighestIndex])
{
// that would be even worse, so let's try to contract uniformly towards the low point;
// don't bother to update the error profile, we'll do it at the start of the
// next iteration
// Accept x_c
AcceptNewVertex(contractionPoint, contractionPointValue, ref errorProfile, vertices, errorValues);
}
// else
else
{
// Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink}
// Compute f_i = f(v_i) for i = 2, n + 1
ShrinkSimplex(errorProfile, vertices, errorValues, objectiveFunction);
evaluationCount += numVertices; // that required one function evaluation for each vertex; keep track
}
}
// check to see if we have exceeded our alloted number of evaluations
// check to see if we have exceeded our allotted number of evaluations
if (evaluationCount >= maximumIterations)
{
throw new MaximumIterationsException(FormattableString.Invariant($"Maximum iterations ({maximumIterations}) reached."));
}
}

objectiveFunction.EvaluateAt(vertices[errorProfile.LowestIndex]);
var regressionResult = new MinimizationResult(objectiveFunction, evaluationCount, exitCondition);
return regressionResult;
Expand Down Expand Up @@ -293,38 +356,42 @@ static Vector<double>[] InitializeVertices(SimplexConstant[] simplexConstants)
}

/// <summary>
/// Test a scaling operation of the high point, and replace it if it is an improvement
/// Calculates a new simplex by moving the worst point along the line given by itself and the centroid.
/// </summary>
/// <param name="scaleFactor"></param>
/// <param name="errorProfile"></param>
/// <param name="vertices"></param>
/// <param name="errorValues"></param>
/// <param name="objectiveFunction"></param>
/// <returns></returns>
static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile, Vector<double>[] vertices,
double[] errorValues, IObjectiveFunction objectiveFunction)
/// <remarks>This is called the x-function in the paper https://www.scilab.org/sites/default/files/neldermead.pdf (4.4)</remarks>
/// <param name="scaleFactor">The factor to scale along the given line.</param>
/// <param name="errorProfile">The error profile.</param>
/// <param name="centroid">The centroid of the simplex.</param>
/// <param name="vertices">The simplex.</param>
/// <param name="objectiveFunction">The objective function.</param>
/// <returns>The point that would replace the worst vertex, thus defining the scaled simplex, together with the objective function value at that point.</returns>
static (Vector<double> scaledPoint, double scaledValue) ScaleSimplex(double scaleFactor, ref ErrorProfile errorProfile,
Vector<double> centroid, Vector<double>[] vertices, IObjectiveFunction objectiveFunction)
{
// find the centroid through which we will reflect
Vector<double> centroid = ComputeCentroid(vertices, errorProfile);

// define the vector from the centroid to the high point
Vector<double> centroidToHighPoint = vertices[errorProfile.HighestIndex].Subtract(centroid);
// define the vector from the high point to the centroid
Vector<double> highPointToCentroid = centroid.Subtract(vertices[errorProfile.HighestIndex]);

// scale and position the vector to determine the new trial point
Vector<double> newPoint = centroidToHighPoint.Multiply(scaleFactor).Add(centroid);
Vector<double> newPoint = highPointToCentroid.Multiply(scaleFactor).Add(centroid);

// evaluate the new point
objectiveFunction.EvaluateAt(newPoint);
double newErrorValue = objectiveFunction.Value;

// if it's better, replace the old high point
if (newErrorValue < errorValues[errorProfile.HighestIndex])
{
vertices[errorProfile.HighestIndex] = newPoint;
errorValues[errorProfile.HighestIndex] = newErrorValue;
}
return (newPoint, objectiveFunction.Value);
}

return newErrorValue;
/// <summary>
/// Replaces the worst vertex of the simplex with the given point and stores its error value.
/// </summary>
/// <param name="newPoint">The point that becomes a vertex of the simplex.</param>
/// <param name="newErrorValue">The error value associated with <paramref name="newPoint"/>.</param>
/// <param name="errorProfile">The error profile; its <c>HighestIndex</c> selects the slot that is overwritten.</param>
/// <param name="vertices">The vertices of the simplex; the worst entry is overwritten in place.</param>
/// <param name="errorValues">The error values of the simplex; the worst entry is overwritten in place.</param>
static void AcceptNewVertex(Vector<double> newPoint, double newErrorValue, ref ErrorProfile errorProfile, Vector<double>[] vertices,
    double[] errorValues)
{
    // Both arrays are indexed by the same slot: the one currently marked as the worst vertex.
    // The error profile itself is not modified here.
    var worstIndex = errorProfile.HighestIndex;

    vertices[worstIndex] = newPoint;
    errorValues[worstIndex] = newErrorValue;
}

/// <summary>
Expand All @@ -337,12 +404,21 @@ static double TryToScaleSimplex(double scaleFactor, ref ErrorProfile errorProfil
static void ShrinkSimplex(ErrorProfile errorProfile, Vector<double>[] vertices, double[] errorValues,
IObjectiveFunction objectiveFunction)
{
// Let's try to contract uniformly towards the low point;
// don't bother to update the error profile, we'll do it at the start of the
// next iteration
// In the paper this is written as:
// Compute the points x_i = x_1 + σ(x_i − x_1), i = 2, n + 1 {Shrink}
// Compute f_i = f(v_i) for i = 2, n + 1

Vector<double> lowestVertex = vertices[errorProfile.LowestIndex];
for (int i = 0; i < vertices.Length; i++)
{
if (i != errorProfile.LowestIndex)
{
vertices[i] = (vertices[i].Add(lowestVertex)).Multiply(0.5);
// x_i = x_1 + σ(x_i − x_1) with σ = 1/2 is equal to
// x_i = (x_1 + x_i) / 2
vertices[i] = vertices[i].Add(lowestVertex).Multiply(0.5);
objectiveFunction.EvaluateAt(vertices[i]);
errorValues[i] = objectiveFunction.Value;
}
Expand Down