|
| 1 | +namespace Common |
| 2 | + |
| 3 | +module ConsoleHelper = |
| 4 | + open System |
| 5 | + open Microsoft.ML |
| 6 | + open Microsoft.ML.Runtime.Data |
| 7 | + open Microsoft.ML.Data |
| 8 | + open Microsoft.ML.Core.Data |
| 9 | + open Microsoft.ML.Runtime.Api |
| 10 | + open System.Reflection |
| 11 | + |
| 12 | + let printPrediction prediction = |
| 13 | + printfn "*************************************************" |
| 14 | + printfn "Predicted : %s" prediction |
| 15 | + printfn "*************************************************" |
| 16 | + |
| 17 | + let printRegressionPredictionVersusObserved predictionCount observedCount = |
| 18 | + printfn "-------------------------------------------------" |
| 19 | + printfn "Predicted : %d" predictionCount |
| 20 | + printfn "Actual: %s" observedCount |
| 21 | + printfn "-------------------------------------------------" |
| 22 | + |
| 23 | + let printRegressionMetrics name (metrics : RegressionEvaluator.Result) = |
| 24 | + printfn "*************************************************" |
| 25 | + printfn "* Metrics for %s regression model " name |
| 26 | + printfn "*------------------------------------------------" |
| 27 | + printfn "* LossFn: %.2f" metrics.LossFn |
| 28 | + printfn "* R2 Score: %.2f" metrics.RSquared |
| 29 | + printfn "* Absolute loss: %.2f" metrics.L1 |
| 30 | + printfn "* Squared loss: %.2f" metrics.L2 |
| 31 | + printfn "* RMS loss: %.2f" metrics.Rms |
| 32 | + printfn "*************************************************" |
| 33 | + |
| 34 | + let printBinaryClassificationMetrics name (metrics : BinaryClassifierEvaluator.Result) = |
| 35 | + printfn"************************************************************" |
| 36 | + printfn"* Metrics for %s binary classification model " name |
| 37 | + printfn"*-----------------------------------------------------------" |
| 38 | + printfn"* Accuracy: %.2f%%" (metrics.Accuracy * 100.) |
| 39 | + printfn"* Auc: %.2f%%" (metrics.Auc * 100.) |
| 40 | + printfn"* F1Score: %.2f%%" (metrics.F1Score * 100.) |
| 41 | + printfn"************************************************************" |
| 42 | + |
| 43 | + let printMultiClassClassificationMetrics name (metrics : MultiClassClassifierEvaluator.Result) = |
| 44 | + printfn "************************************************************" |
| 45 | + printfn "* Metrics for %s multi-class classification model " name |
| 46 | + printfn "*-----------------------------------------------------------" |
| 47 | + printfn " AccuracyMacro = %.4f, a value between 0 and 1, the closer to 1, the better" metrics.AccuracyMacro |
| 48 | + printfn " AccuracyMicro = %.4f, a value between 0 and 1, the closer to 1, the better" metrics.AccuracyMicro |
| 49 | + printfn " LogLoss = %.4f, the closer to 0, the better" metrics.LogLoss |
| 50 | + printfn " LogLoss for class 1 = %.4f, the closer to 0, the better" metrics.PerClassLogLoss.[0] |
| 51 | + printfn " LogLoss for class 2 = %.4f, the closer to 0, the better" metrics.PerClassLogLoss.[1] |
| 52 | + printfn " LogLoss for class 3 = %.4f, the closer to 0, the better" metrics.PerClassLogLoss.[2] |
| 53 | + printfn "************************************************************" |
| 54 | + |
| 55 | + |
| 56 | + let private calculateStandardDeviation (values : float array) = |
| 57 | + let average = values |> Array.average |
| 58 | + let sumOfSquaresOfDifferences = values |> Array.map(fun v -> (v - average) * (v - average)) |> Array.sum |
| 59 | + let standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / float (values.Length-1)) |
| 60 | + standardDeviation; |
| 61 | + |
| 62 | + let calculateConfidenceInterval95 (values : float array) = |
| 63 | + let confidenceInterval95 = 1.96 * calculateStandardDeviation(values) / Math.Sqrt(float (values.Length-1)); |
| 64 | + confidenceInterval95 |
| 65 | + |
| 66 | + let printMulticlassClassificationFoldsAverageMetrics algorithmName (crossValResults : (MultiClassClassifierEvaluator.Result * ITransformer * IDataView) array) = |
| 67 | + |
| 68 | + let metricsInMultipleFolds = crossValResults |> Array.map(fun (metrics, model, scoredTestData) -> metrics) |
| 69 | + |
| 70 | + let microAccuracyValues = metricsInMultipleFolds |> Array.map(fun m -> m.AccuracyMicro) |
| 71 | + let microAccuracyAverage = microAccuracyValues |> Array.average |
| 72 | + let microAccuraciesStdDeviation = calculateStandardDeviation microAccuracyValues |
| 73 | + let microAccuraciesConfidenceInterval95 = calculateConfidenceInterval95 microAccuracyValues |
| 74 | + |
| 75 | + let macroAccuracyValues = metricsInMultipleFolds |> Array.map(fun m -> m.AccuracyMacro) |
| 76 | + let macroAccuracyAverage = macroAccuracyValues |> Array.average |
| 77 | + let macroAccuraciesStdDeviation = calculateStandardDeviation macroAccuracyValues |
| 78 | + let macroAccuraciesConfidenceInterval95 = calculateConfidenceInterval95 macroAccuracyValues |
| 79 | + |
| 80 | + let logLossValues = metricsInMultipleFolds |> Array.map (fun m -> m.LogLoss) |
| 81 | + let logLossAverage = logLossValues |> Array.average |
| 82 | + let logLossStdDeviation = calculateStandardDeviation logLossValues |
| 83 | + let logLossConfidenceInterval95 = calculateConfidenceInterval95 logLossValues |
| 84 | + |
| 85 | + let logLossReductionValues = metricsInMultipleFolds |> Array.map (fun m -> m.LogLossReduction) |
| 86 | + let logLossReductionAverage = logLossReductionValues |> Array.average |
| 87 | + let logLossReductionStdDeviation = calculateStandardDeviation logLossReductionValues |
| 88 | + let logLossReductionConfidenceInterval95 = calculateConfidenceInterval95 logLossReductionValues |
| 89 | + |
| 90 | + printfn "*************************************************************************************************************" |
| 91 | + printfn "* Metrics for %s Multi-class Classification model " algorithmName |
| 92 | + printfn "*------------------------------------------------------------------------------------------------------------" |
| 93 | + printfn "* Average MicroAccuracy: %.3f - Standard deviation: (%.3f) - Confidence Interval 95%%: (%.3f)" microAccuracyAverage microAccuraciesStdDeviation microAccuraciesConfidenceInterval95 |
| 94 | + printfn "* Average MacroAccuracy: %.3f - Standard deviation: (%.3f) - Confidence Interval 95%%: (%.3f)" macroAccuracyAverage macroAccuraciesStdDeviation macroAccuraciesConfidenceInterval95 |
| 95 | + printfn "* Average LogLoss: %.3f - Standard deviation: (%.3f) - Confidence Interval 95%%: (%.3f)" logLossAverage logLossStdDeviation logLossConfidenceInterval95 |
| 96 | + printfn "* Average LogLossReduction: %.3f - Standard deviation: (%.3f) - Confidence Interval 95%%: (%.3f)" logLossReductionAverage logLossReductionStdDeviation logLossReductionConfidenceInterval95 |
| 97 | + printfn "*************************************************************************************************************" |
| 98 | + |
| 99 | + let printClusteringMetrics name (metrics : ClusteringEvaluator.Result) = |
| 100 | + printfn "*************************************************" |
| 101 | + printfn "* Metrics for %s clustering model " name |
| 102 | + printfn "*------------------------------------------------" |
| 103 | + printfn "* AvgMinScore: %.4f" metrics.AvgMinScore |
| 104 | + printfn "* DBI is: %.4f" metrics.Dbi |
| 105 | + printfn "*************************************************" |
| 106 | + |
| 107 | + let consoleWriteHeader (lines : string array) = |
| 108 | + let defaultColor = Console.ForegroundColor |
| 109 | + Console.ForegroundColor <- ConsoleColor.Yellow |
| 110 | + printfn " " |
| 111 | + for line in lines do |
| 112 | + printfn "%s" line |
| 113 | + let maxLength = lines |> Array.map(fun x -> x.Length) |> Array.max |
| 114 | + printfn "%s" (new string('#', maxLength)) |
| 115 | + Console.ForegroundColor <- defaultColor |
| 116 | + |
| 117 | + let peekDataViewInConsole<'TObservation when 'TObservation : (new : unit -> 'TObservation) and 'TObservation : not struct> (mlContext : MLContext) (dataView : IDataView) (pipeline : IEstimator<ITransformer>) numberOfRows = |
| 118 | + |
| 119 | + let msg = sprintf "Peek data in DataView: Showing %d rows with the columns specified by TObservation class" numberOfRows |
| 120 | + consoleWriteHeader [| msg |] |
| 121 | + |
| 122 | + //https://github.com/dotnet/machinelearning/blob/master/docs/code/MlNetCookBook.md#how-do-i-look-at-the-intermediate-data |
| 123 | + let transformer = pipeline.Fit dataView |
| 124 | + let transformedData = transformer.Transform dataView |
| 125 | + |
| 126 | + // 'transformedData' is a 'promise' of data, lazy-loading. Let's actually read it. |
| 127 | + // Convert to an enumerable of user-defined type. |
| 128 | + let someRows = |
| 129 | + transformedData.AsEnumerable<'TObservation>(mlContext, reuseRowObject = false) |
| 130 | + // Take the specified number of rows |
| 131 | + |> Seq.take numberOfRows |
| 132 | + // Convert to List |
| 133 | + |> Seq.toList |
| 134 | + |
| 135 | + someRows |
| 136 | + |> List.iter(fun row -> |
| 137 | + |
| 138 | + let lineToPrint = |
| 139 | + row.GetType().GetFields(BindingFlags.Instance ||| BindingFlags.Static ||| BindingFlags.NonPublic ||| BindingFlags.Public) |
| 140 | + |> Array.map(fun field -> sprintf "| %s: %O" field.Name (field.GetValue(row))) |
| 141 | + |> Array.fold (+) "Row--> " |
| 142 | + |
| 143 | + printfn "%s" lineToPrint |
| 144 | + ) |
| 145 | + |
| 146 | + someRows |
| 147 | + |
| 148 | + let peekVectorColumnDataInConsole (mlContext : MLContext) columnName (dataView : IDataView) (pipeline : IEstimator<ITransformer>) numberOfRows = |
| 149 | + let msg = sprintf "Peek data in DataView: : Show %d rows with just the '%s' column" numberOfRows columnName |
| 150 | + consoleWriteHeader [| msg |] |
| 151 | + |
| 152 | + let transformer = pipeline.Fit dataView |
| 153 | + let transformedData = transformer.Transform dataView |
| 154 | + |
| 155 | + // Extract the 'Features' column. |
| 156 | + let someColumnData = |
| 157 | + transformedData.GetColumn<float32[]>(mlContext, columnName) |
| 158 | + |> Seq.take numberOfRows |
| 159 | + |> Seq.toList |
| 160 | + |
| 161 | + // print to console the peeked rows |
| 162 | + someColumnData |
| 163 | + |> List.iter(fun row -> |
| 164 | + let concatColumn = |
| 165 | + row |
| 166 | + |> Array.map string |
| 167 | + |> Array.fold (+) " " |
| 168 | + printfn "%s" concatColumn |
| 169 | + ) |
| 170 | + |
| 171 | + someColumnData; |
| 172 | + |
| 173 | + let consoleWriterSection (lines : string array) = |
| 174 | + let defaultColor = Console.ForegroundColor |
| 175 | + Console.ForegroundColor <- ConsoleColor.Blue |
| 176 | + printfn " " |
| 177 | + lines |
| 178 | + |> Array.iter (printfn "%s") |
| 179 | + |
| 180 | + let maxLength = lines |> Array.map(fun x -> x.Length) |> Array.max |
| 181 | + printfn "%s" (new string('-', maxLength)) |
| 182 | + Console.ForegroundColor <- defaultColor |
| 183 | + |
| 184 | + let consolePressAnyKey () = |
| 185 | + let defaultColor = Console.ForegroundColor |
| 186 | + Console.ForegroundColor <- ConsoleColor.Green |
| 187 | + printfn " " |
| 188 | + printfn "Press any key to finish." |
| 189 | + Console.ForegroundColor <- defaultColor |
| 190 | + Console.ReadKey() |> ignore |
| 191 | + |
| 192 | + let consoleWriteException (lines : string array) = |
| 193 | + let defaultColor = Console.ForegroundColor |
| 194 | + Console.ForegroundColor <- ConsoleColor.Red |
| 195 | + let exceptionTitle = "EXCEPTION" |
| 196 | + printfn " " |
| 197 | + printfn "%s" exceptionTitle |
| 198 | + printfn "%s" (new string('#', exceptionTitle.Length)) |
| 199 | + Console.ForegroundColor <- defaultColor |
| 200 | + lines |
| 201 | + |> Array.iter (printfn "%s") |
| 202 | + |
| 203 | + let consoleWriteWarning (lines : string array) = |
| 204 | + let defaultColor = Console.ForegroundColor |
| 205 | + Console.ForegroundColor <- ConsoleColor.DarkMagenta |
| 206 | + let warningTitle = "WARNING" |
| 207 | + printfn " " |
| 208 | + printfn "%s" warningTitle |
| 209 | + printfn "%s" (new string('#', warningTitle.Length)) |
| 210 | + Console.ForegroundColor <- defaultColor |
| 211 | + lines |
| 212 | + |> Array.iter (printfn "%s") |
0 commit comments