diff --git a/Project.toml b/Project.toml
index c41e46153..30d814efd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
+MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl
index 8b5aab1ea..5541d28cc 100644
--- a/src/FluxTraining.jl
+++ b/src/FluxTraining.jl
@@ -25,6 +25,7 @@ using ChainRulesCore
 import ParameterSchedulers
 import ParameterSchedulers: Sequence, Shifted, Sin
 using TensorBoardLogger: TBLogger, log_value, log_image, log_text, log_histogram, tb_overwrite
+using MLFlowClient
 using Zygote: Grads, gradient
 using ValueHistories
 using DataStructures: DefaultDict, PriorityQueue, enqueue!, dequeue!
@@ -48,6 +49,7 @@ include("./callbacks/execution.jl")
 include("./callbacks/logging/Loggables.jl")
 include("./callbacks/logging/logger.jl")
 include("./callbacks/logging/tensorboard.jl")
+include("./callbacks/logging/mlflow.jl")
 include("./callbacks/logging/checkpointer.jl")
 
 
@@ -111,6 +113,7 @@ export AbstractCallback,
     LogHyperParams,
     LogVisualization,
     TensorBoardBackend,
+    MLFlowBackend,
     StopOnNaNLoss,
     LearningRate,
     throttle,
diff --git a/src/callbacks/logging/mlflow.jl b/src/callbacks/logging/mlflow.jl
new file mode 100644
index 000000000..47e17bdfc
--- /dev/null
+++ b/src/callbacks/logging/mlflow.jl
@@ -0,0 +1,34 @@
+using MLFlowClient
+
+"""
+    MLFlowBackend(tracking_uri, experiment_name[; kwargs...])
+
+MLFlow backend for logging metrics. Creates an `MLFlowClient.MLFlow` client for
+`tracking_uri`, gets the experiment named `experiment_name` (creating it if it
+doesn't exist) and starts a new run in it.
+
+Keyword arguments are forwarded to `MLFlowClient.createrun` when the run is started.
+"""
+struct MLFlowBackend <: LoggerBackend
+    mlf::MLFlowClient.MLFlow
+    experiment::MLFlowClient.MLFlowExperiment
+    run::MLFlowClient.MLFlowRun
+    function MLFlowBackend(tracking_uri::String, experiment_name::String; kwargs...)
+        mlf = MLFlowClient.MLFlow(tracking_uri)
+        experiment = MLFlowClient.getorcreateexperiment(mlf, experiment_name)
+        # Forward the keyword arguments to the run instead of silently dropping them.
+        run = MLFlowClient.createrun(mlf, experiment; kwargs...)
+        return new(mlf, experiment, run)
+    end
+end
+
+# Show the client's `apiroot` and the experiment's name.
+Base.show(io::IO, backend::MLFlowBackend) = print(
+    io, "MLFlowBackend(", backend.mlf.apiroot, ", ", backend.experiment.name, ")")
+
+# Log `value.data` as an MLFlow metric for the backend's run at `step`; the metric
+# name is `name` combined with `group` by `_combinename`.
+function log_to(backend::MLFlowBackend, value::Loggables.Value, name, step; group = ())
+    name = _combinename(name, group)
+    MLFlowClient.logmetric(backend.mlf, backend.run, name, value.data, step=step)
+end