Skip to content

Commit

Permalink
src/callbacks/logging/mlflow.jl: Add basic MLFlowBackend.
Browse files Browse the repository at this point in the history
  • Loading branch information
mashu committed Sep 12, 2024
1 parent 2fa2fb3 commit 7148e87
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Expand Down
3 changes: 3 additions & 0 deletions src/FluxTraining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using ChainRulesCore
import ParameterSchedulers
import ParameterSchedulers: Sequence, Shifted, Sin
using TensorBoardLogger: TBLogger, log_value, log_image, log_text, log_histogram, tb_overwrite
using MLFlowClient
using Zygote: Grads, gradient
using ValueHistories
using DataStructures: DefaultDict, PriorityQueue, enqueue!, dequeue!
Expand All @@ -48,6 +49,7 @@ include("./callbacks/execution.jl")
include("./callbacks/logging/Loggables.jl")
include("./callbacks/logging/logger.jl")
include("./callbacks/logging/tensorboard.jl")
include("./callbacks/logging/mlflow.jl")
include("./callbacks/logging/checkpointer.jl")


Expand Down Expand Up @@ -111,6 +113,7 @@ export AbstractCallback,
LogHyperParams,
LogVisualization,
TensorBoardBackend,
MLFlowBackend,
StopOnNaNLoss,
LearningRate,
throttle,
Expand Down
27 changes: 27 additions & 0 deletions src/callbacks/logging/mlflow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using MLFlowClient

"""
MLFlowBackend(tracking_uri, experiment_name[; kwargs...])
MLFlow backend for logging metrics. Creates a new experiment if it doesn't exist
and starts a new run.
"""
struct MLFlowBackend <: LoggerBackend
mlf::MLFlowClient.MLFlow
experiment::MLFlowClient.MLFlowExperiment
run::MLFlowClient.MLFlowRun
function MLFlowBackend(tracking_uri::String, experiment_name::String; kwargs...)
mlf = MLFlowClient.MLFlow(tracking_uri)
experiment = MLFlowClient.getorcreateexperiment(mlf, experiment_name)
run = MLFlowClient.createrun(mlf, experiment)
return new(mlf, experiment, run)
end
end

Base.show(io::IO, backend::MLFlowBackend) = print(
io, "MLFlowBackend(", backend.mlf.apiroot, ", ", backend.experiment.name, ")")

function log_to(backend::MLFlowBackend, value::Loggables.Value, name, step; group = ())
name = _combinename(name, group)
MLFlowClient.logmetric(backend.mlf, backend.run, name, value.data, step=step)
end

0 comments on commit 7148e87

Please sign in to comment.