Skip to content

Commit 5988019

Browse files
committed
Cleaned up more default parameters and removed unused options since interface implementation will take care of them
1 parent d7053f7 commit 5988019

File tree

4 files changed

+34
-43
lines changed

4 files changed

+34
-43
lines changed

src/Interfaces/IOptimizationAlgorithm.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@ public OptimizationResult<T> Optimize(
1010
Matrix<T> XTest,
1111
Vector<T> yTest,
1212
PredictionModelOptions modelOptions,
13-
OptimizationAlgorithmOptions optimizationOptions,
1413
IRegression<T> regressionMethod,
1514
IRegularization<T> regularization,
1615
INormalizer<T> normalizer,
1716
NormalizationInfo<T> normInfo,
1817
IFitnessCalculator<T> fitnessCalculator,
1918
IFitDetector<T> fitDetector);
2019

21-
bool ShouldEarlyStop(List<OptimizationIterationInfo<T>> iterationHistory, OptimizationAlgorithmOptions options, IFitnessCalculator<T> fitnessCalculator);
20+
bool ShouldEarlyStop(List<OptimizationIterationInfo<T>> iterationHistory, IFitnessCalculator<T> fitnessCalculator);
2221
}

src/Interfaces/IPredictionModelBuilder.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
public interface IPredictionModelBuilder<T>
44
{
5-
IPredictionModelBuilder<T> WithFeatureSelector(IFeatureSelector<T> selector);
6-
IPredictionModelBuilder<T> WithNormalizer(INormalizer<T> normalizer);
7-
IPredictionModelBuilder<T> WithRegularization(IRegularization<T> regularization, RegularizationOptions? regularizationOptions = null);
8-
IPredictionModelBuilder<T> WithFitnessCalculator(IFitnessCalculator<T> calculator, FitnessCalculatorOptions? _fitnessCalculatorOptions = null);
9-
IPredictionModelBuilder<T> WithFitDetector(IFitDetector<T> detector);
10-
IPredictionModelBuilder<T> WithRegression(IRegression<T> regression, RegressionOptions<T>? regressionOptions = null);
11-
IPredictionModelBuilder<T> WithOptimizer(IOptimizationAlgorithm<T> optimizationAlgorithm, OptimizationAlgorithmOptions? optimizationOptions = null);
12-
IPredictionModelBuilder<T> WithDataPreprocessor(IDataPreprocessor<T> dataPreprocessor);
13-
IPredictionModelBuilder<T> WithOutlierRemoval(IOutlierRemoval<T> outlierRemoval);
5+
IPredictionModelBuilder<T> ConfigureFeatureSelector(IFeatureSelector<T> selector);
6+
IPredictionModelBuilder<T> ConfigureNormalizer(INormalizer<T> normalizer);
7+
IPredictionModelBuilder<T> ConfigureRegularization(IRegularization<T> regularization);
8+
IPredictionModelBuilder<T> ConfigureFitnessCalculator(IFitnessCalculator<T> calculator);
9+
IPredictionModelBuilder<T> ConfigureFitDetector(IFitDetector<T> detector);
10+
IPredictionModelBuilder<T> ConfigureRegression(IRegression<T> regression);
11+
IPredictionModelBuilder<T> ConfigureOptimizer(IOptimizationAlgorithm<T> optimizationAlgorithm);
12+
IPredictionModelBuilder<T> ConfigureDataPreprocessor(IDataPreprocessor<T> dataPreprocessor);
13+
IPredictionModelBuilder<T> ConfigureOutlierRemoval(IOutlierRemoval<T> outlierRemoval);
1414
PredictionModelResult<T> Build(Matrix<T> x, Vector<T> y);
1515
Vector<T> Predict(Matrix<T> newData, PredictionModelResult<T> model);
1616
void SaveModel(PredictionModelResult<T> model, string filePath);

src/LinearAlgebra/PredictionModelBuilder.cs

+10-19
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ namespace AiDotNet.LinearAlgebra;
1010
public class PredictionModelBuilder<T> : IPredictionModelBuilder<T>
1111
{
1212
private readonly PredictionModelOptions _options;
13-
private OptimizationAlgorithmOptions? _optimizationOptions;
14-
private RegularizationOptions? _regularizationOptions;
15-
private RegressionOptions<T>? _regressionOptions;
16-
private FitnessCalculatorOptions? _fitnessCalculatorOptions;
1713
private IFeatureSelector<T>? _featureSelector;
1814
private INormalizer<T>? _normalizer;
1915
private IRegularization<T>? _regularization;
@@ -29,59 +25,55 @@ public PredictionModelBuilder(PredictionModelOptions? options = null)
2925
_options = options ?? new PredictionModelOptions();
3026
}
3127

32-
public IPredictionModelBuilder<T> WithFeatureSelector(IFeatureSelector<T> selector)
28+
public IPredictionModelBuilder<T> ConfigureFeatureSelector(IFeatureSelector<T> selector)
3329
{
3430
_featureSelector = selector;
3531
return this;
3632
}
3733

38-
public IPredictionModelBuilder<T> WithNormalizer(INormalizer<T> normalizer)
34+
public IPredictionModelBuilder<T> ConfigureNormalizer(INormalizer<T> normalizer)
3935
{
4036
_normalizer = normalizer;
4137
return this;
4238
}
4339

44-
public IPredictionModelBuilder<T> WithRegularization(IRegularization<T> regularization, RegularizationOptions? regularizationOptions = null)
40+
public IPredictionModelBuilder<T> ConfigureRegularization(IRegularization<T> regularization)
4541
{
4642
_regularization = regularization;
47-
_regularizationOptions = regularizationOptions;
4843
return this;
4944
}
5045

51-
public IPredictionModelBuilder<T> WithFitnessCalculator(IFitnessCalculator<T> calculator, FitnessCalculatorOptions? fitnessCalculatorOptions = null)
46+
public IPredictionModelBuilder<T> ConfigureFitnessCalculator(IFitnessCalculator<T> calculator)
5247
{
5348
_fitnessCalculator = calculator;
54-
_fitnessCalculatorOptions = fitnessCalculatorOptions;
5549
return this;
5650
}
5751

58-
public IPredictionModelBuilder<T> WithFitDetector(IFitDetector<T> detector)
52+
public IPredictionModelBuilder<T> ConfigureFitDetector(IFitDetector<T> detector)
5953
{
6054
_fitDetector = detector;
6155
return this;
6256
}
6357

64-
public IPredictionModelBuilder<T> WithRegression(IRegression<T> regression, RegressionOptions<T>? regressionOptions = null)
58+
public IPredictionModelBuilder<T> ConfigureRegression(IRegression<T> regression)
6559
{
6660
_regression = regression;
67-
_regressionOptions = regressionOptions;
6861
return this;
6962
}
7063

71-
public IPredictionModelBuilder<T> WithOptimizer(IOptimizationAlgorithm<T> optimizationAlgorithm, OptimizationAlgorithmOptions? optimizationOptions = null)
64+
public IPredictionModelBuilder<T> ConfigureOptimizer(IOptimizationAlgorithm<T> optimizationAlgorithm)
7265
{
7366
_optimizer = optimizationAlgorithm;
74-
_optimizationOptions = optimizationOptions;
7567
return this;
7668
}
7769

78-
public IPredictionModelBuilder<T> WithDataPreprocessor(IDataPreprocessor<T> dataPreprocessor)
70+
public IPredictionModelBuilder<T> ConfigureDataPreprocessor(IDataPreprocessor<T> dataPreprocessor)
7971
{
8072
_dataPreprocessor = dataPreprocessor;
8173
return this;
8274
}
8375

84-
public IPredictionModelBuilder<T> WithOutlierRemoval(IOutlierRemoval<T> outlierRemoval)
76+
public IPredictionModelBuilder<T> ConfigureOutlierRemoval(IOutlierRemoval<T> outlierRemoval)
8577
{
8678
_outlierRemoval = outlierRemoval;
8779
return this;
@@ -102,7 +94,6 @@ public PredictionModelResult<T> Build(Matrix<T> x, Vector<T> y)
10294
// Use defaults for these interfaces if they aren't set
10395
var normalizer = _normalizer ?? new NoNormalizer<T>();
10496
var optimizer = _optimizer ?? new NormalOptimizer<T>();
105-
var optimizerOptions = _optimizationOptions ?? new OptimizationAlgorithmOptions();
10697
var featureSelector = _featureSelector ?? new NoFeatureSelector<T>();
10798
var fitDetector = _fitDetector ?? new DefaultFitDetector<T>();
10899
var fitnessCalculator = _fitnessCalculator ?? new RSquaredFitnessCalculator<T>();
@@ -117,7 +108,7 @@ public PredictionModelResult<T> Build(Matrix<T> x, Vector<T> y)
117108
var (XTrain, yTrain, XVal, yVal, XTest, yTest) = dataPreprocessor.SplitData(preprocessedX, preprocessedY);
118109

119110
// Optimize the model
120-
var optimizationResult = optimizer.Optimize(XTrain, yTrain, XVal, yVal, XTest, yTest, _options, optimizerOptions, _regression, regularization, normalizer,
111+
var optimizationResult = optimizer.Optimize(XTrain, yTrain, XVal, yVal, XTest, yTest, _options, _regression, regularization, normalizer,
121112
normInfo, fitnessCalculator, fitDetector);
122113

123114
return new PredictionModelResult<T>

src/Optimizers/NormalOptimizer.cs

+14-13
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ public class NormalOptimizer<T> : IOptimizationAlgorithm<T>
44
{
55
private readonly Random _random = new();
66
private readonly INumericOperations<T> _numOps;
7+
private readonly OptimizationAlgorithmOptions _optimizationOptions;
78

8-
public NormalOptimizer(INumericOperations<T>? numericOperations = null)
9+
public NormalOptimizer(OptimizationAlgorithmOptions? optimizationOptions = null)
910
{
10-
_numOps = numericOperations ?? MathHelper.GetNumericOperations<T>();
11+
_numOps = MathHelper.GetNumericOperations<T>();
12+
_optimizationOptions = optimizationOptions ?? new OptimizationAlgorithmOptions();
1113
}
1214

1315
public OptimizationResult<T> Optimize(
@@ -18,7 +20,6 @@ public OptimizationResult<T> Optimize(
1820
Matrix<T> XTest,
1921
Vector<T> yTest,
2022
PredictionModelOptions modelOptions,
21-
OptimizationAlgorithmOptions optimizationOptions,
2223
IRegression<T> regressionMethod,
2324
IRegularization<T> regularization,
2425
INormalizer<T> normalizer,
@@ -28,7 +29,7 @@ public OptimizationResult<T> Optimize(
2829
{
2930
var bestSolution = new Vector<T>(XTrain.Columns, _numOps);
3031
var bestIntercept = _numOps.Zero;
31-
T bestFitness = optimizationOptions.MaximizeFitness ? _numOps.MinValue : _numOps.MaxValue;
32+
T bestFitness = _optimizationOptions.MaximizeFitness ? _numOps.MinValue : _numOps.MaxValue;
3233
var fitnessHistory = new List<T>();
3334
var iterationHistory = new List<OptimizationIterationInfo<T>>();
3435
var bestSelectedFeatures = new List<Vector<T>>();
@@ -52,7 +53,7 @@ public OptimizationResult<T> Optimize(
5253
PredictionStats<T>? bestValidationPredictionStats = null;
5354
PredictionStats<T>? bestTestPredictionStats = null;
5455

55-
for (int iteration = 0; iteration < optimizationOptions.MaxIterations; iteration++)
56+
for (int iteration = 0; iteration < _optimizationOptions.MaxIterations; iteration++)
5657
{
5758
// Randomly select features
5859
var selectedFeatures = RandomlySelectFeatures(XTrain.Columns, modelOptions.MinimumFeatures, modelOptions.MaximumFeatures);
@@ -87,9 +88,9 @@ public OptimizationResult<T> Optimize(
8788
var testActualBasicStats = new BasicStats<T>(yTest);
8889
var testPredictedBasicStats = new BasicStats<T>(testPredictions);
8990

90-
var trainingPredictionStats = new PredictionStats<T>(yTrain, trainingPredictions, featureCount, _numOps.FromDouble(optimizationOptions.ConfidenceLevel), _numOps);
91-
var validationPredictionStats = new PredictionStats<T>(yVal, validationPredictions, featureCount, _numOps.FromDouble(optimizationOptions.ConfidenceLevel), _numOps);
92-
var testPredictionStats = new PredictionStats<T>(yTest, testPredictions, featureCount, _numOps.FromDouble(optimizationOptions.ConfidenceLevel), _numOps);
91+
var trainingPredictionStats = new PredictionStats<T>(yTrain, trainingPredictions, featureCount, _numOps.FromDouble(_optimizationOptions.ConfidenceLevel), _numOps);
92+
var validationPredictionStats = new PredictionStats<T>(yVal, validationPredictions, featureCount, _numOps.FromDouble(_optimizationOptions.ConfidenceLevel), _numOps);
93+
var testPredictionStats = new PredictionStats<T>(yTest, testPredictions, featureCount, _numOps.FromDouble(_optimizationOptions.ConfidenceLevel), _numOps);
9394

9495
// Detect fit type
9596
var fitDetectionResult = fitDetector.DetectFit(
@@ -141,7 +142,7 @@ public OptimizationResult<T> Optimize(
141142
});
142143

143144
// Check for early stopping
144-
if (optimizationOptions.UseEarlyStopping && ShouldEarlyStop(iterationHistory, optimizationOptions, fitnessCalculator))
145+
if (_optimizationOptions.UseEarlyStopping && ShouldEarlyStop(iterationHistory, fitnessCalculator))
145146
{
146147
break;
147148
}
@@ -194,14 +195,14 @@ public OptimizationResult<T> Optimize(
194195
};
195196
}
196197

197-
public bool ShouldEarlyStop(List<OptimizationIterationInfo<T>> iterationHistory, OptimizationAlgorithmOptions options, IFitnessCalculator<T> fitnessCalculator)
198+
public bool ShouldEarlyStop(List<OptimizationIterationInfo<T>> iterationHistory, IFitnessCalculator<T> fitnessCalculator)
198199
{
199-
if (iterationHistory.Count < options.EarlyStoppingPatience)
200+
if (iterationHistory.Count < _optimizationOptions.EarlyStoppingPatience)
200201
{
201202
return false;
202203
}
203204

204-
var recentIterations = iterationHistory.Skip(Math.Max(0, iterationHistory.Count - options.EarlyStoppingPatience)).ToList();
205+
var recentIterations = iterationHistory.Skip(Math.Max(0, iterationHistory.Count - _optimizationOptions.EarlyStoppingPatience)).ToList();
205206

206207
// Find the best fitness score
207208
T bestFitness = iterationHistory[0].Fitness;
@@ -238,7 +239,7 @@ public bool ShouldEarlyStop(List<OptimizationIterationInfo<T>> iterationHistory,
238239
}
239240
}
240241

241-
return noImprovement || consecutiveBadFits >= options.BadFitPatience;
242+
return noImprovement || consecutiveBadFits >= _optimizationOptions.BadFitPatience;
242243
}
243244

244245
private List<int> RandomlySelectFeatures(int totalFeatures, int minFeatures, int maxFeatures)

0 commit comments

Comments
 (0)