Skip to content

Commit d7053f7

Browse files
committed
Set default values for regularization and outlier removal so user isn't forced to provide those parameters
1 parent 6bad169 commit d7053f7

8 files changed

+38
-22
lines changed

src/Models/RegularizationOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ public class RegularizationOptions
44
{
55
public RegularizationType Type { get; set; } = RegularizationType.None;
66
public double Strength { get; set; } = 0.0;
7-
public double L1Ratio { get; set; } = 0.5; // Only used for ElasticNet
7+
public double L1Ratio { get; set; } = 0.5;
88
}

src/OutlierRemoval/IQROutlierRemoval.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ public class IQROutlierRemoval<T> : IOutlierRemoval<T>
55
private readonly T _iqrMultiplier;
66
private readonly INumericOperations<T> _numOps;
77

8-
public IQROutlierRemoval(T iqrMultiplier)
8+
public IQROutlierRemoval(T? iqrMultiplier = default)
99
{
10-
_iqrMultiplier = iqrMultiplier;
1110
_numOps = MathHelper.GetNumericOperations<T>();
11+
_iqrMultiplier = iqrMultiplier ?? GetDefaultMultiplier();
1212
}
1313

1414
public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
@@ -45,4 +45,9 @@ public IQROutlierRemoval(T iqrMultiplier)
4545

4646
return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
4747
}
48+
49+
private T GetDefaultMultiplier()
50+
{
51+
return _numOps.FromDouble(1.5);
52+
}
4853
}

src/OutlierRemoval/MADOutlierRemoval.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ public class MADOutlierRemoval<T> : IOutlierRemoval<T>
55
private readonly T _threshold;
66
private readonly INumericOperations<T> _numOps;
77

8-
public MADOutlierRemoval(T threshold)
8+
public MADOutlierRemoval(T? threshold = default)
99
{
10-
_threshold = threshold;
1110
_numOps = MathHelper.GetNumericOperations<T>();
11+
_threshold = threshold ?? GetDefaultThreshold();
1212
}
1313

1414
public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
@@ -43,4 +43,9 @@ public MADOutlierRemoval(T threshold)
4343

4444
return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
4545
}
46+
47+
private T GetDefaultThreshold()
48+
{
49+
return _numOps.FromDouble(3.5);
50+
}
4651
}

src/OutlierRemoval/ThresholdOutlierRemoval.cs

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
namespace AiDotNet.OutlierRemoval;
22

3-
/// <summary>
4-
/// Removes outliers from the data using the threshold method. This method is not recommended for data sets with less than 15 data points.
5-
/// </summary>
63
public class ThresholdOutlierRemoval<T> : IOutlierRemoval<T>
74
{
85
private readonly T _threshold;
96
private readonly INumericOperations<T> _numOps;
107

11-
public ThresholdOutlierRemoval(T threshold)
8+
public ThresholdOutlierRemoval(T? threshold = default)
129
{
13-
_threshold = threshold;
1410
_numOps = MathHelper.GetNumericOperations<T>();
11+
_threshold = threshold ?? GetDefaultThreshold();
1512
}
1613

1714
public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
@@ -42,4 +39,9 @@ public ThresholdOutlierRemoval(T threshold)
4239

4340
return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
4441
}
42+
43+
private T GetDefaultThreshold()
44+
{
45+
return _numOps.FromDouble(3.0);
46+
}
4547
}

src/OutlierRemoval/ZScoreOutlierRemoval.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ public class ZScoreOutlierRemoval<T> : IOutlierRemoval<T>
55
private readonly T _threshold;
66
private readonly INumericOperations<T> _numOps;
77

8-
public ZScoreOutlierRemoval(T threshold)
8+
public ZScoreOutlierRemoval(T? threshold = default)
99
{
10-
_threshold = threshold;
1110
_numOps = MathHelper.GetNumericOperations<T>();
11+
_threshold = threshold ?? FindDefaultThreshold();
1212
}
1313

1414
public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
@@ -41,4 +41,9 @@ public ZScoreOutlierRemoval(T threshold)
4141

4242
return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
4343
}
44+
45+
private T FindDefaultThreshold()
46+
{
47+
return _numOps.FromDouble(3.0);
48+
}
4449
}

src/Regularization/ElasticRegularization.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ public class ElasticNetRegularization<T> : IRegularization<T>
55
private readonly INumericOperations<T> _numOps;
66
private readonly RegularizationOptions _options;
77

8-
public ElasticNetRegularization(INumericOperations<T> numOps, RegularizationOptions options)
8+
public ElasticNetRegularization(RegularizationOptions? options = null)
99
{
10-
_numOps = numOps;
11-
_options = options;
10+
_numOps = MathHelper.GetNumericOperations<T>();
11+
_options = options ?? new RegularizationOptions();
1212
}
1313

1414
public Matrix<T> RegularizeMatrix(Matrix<T> matrix)

src/Regularization/L1Regularization.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@ public class L1Regularization<T> : IRegularization<T>
55
private readonly INumericOperations<T> _numOps;
66
private readonly RegularizationOptions _options;
77

8-
public L1Regularization(INumericOperations<T> numOps, RegularizationOptions options)
8+
public L1Regularization(RegularizationOptions? options = null)
99
{
10-
_numOps = numOps;
11-
_options = options;
10+
_numOps = MathHelper.GetNumericOperations<T>();
11+
_options = options ?? new RegularizationOptions();
1212
}
1313

1414
public Matrix<T> RegularizeMatrix(Matrix<T> matrix)
1515
{
16-
// L1 regularization doesn't modify the matrix directly
1716
return matrix;
1817
}
1918

src/Regularization/L2Regularization.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ public class L2Regularization<T> : IRegularization<T>
55
private readonly INumericOperations<T> _numOps;
66
private readonly RegularizationOptions _options;
77

8-
public L2Regularization(INumericOperations<T> numOps, RegularizationOptions options)
8+
public L2Regularization(RegularizationOptions? options = null)
99
{
10-
_numOps = numOps;
11-
_options = options;
10+
_numOps = MathHelper.GetNumericOperations<T>();
11+
_options = options ?? new RegularizationOptions();
1212
}
1313

1414
public Matrix<T> RegularizeMatrix(Matrix<T> matrix)

0 commit comments

Comments
 (0)