Skip to content

Fix MLFlowLogger.save_dir Windows file URI handling (Fixes #20972) #20988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/lightning/pytorch/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,18 @@ def save_dir(self) -> Optional[str]:

"""
if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
# Handle both proper file URIs (file:///path) and legacy format (file:/path)
uri_without_prefix = self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]

# If it starts with ///, it's a proper file URI, use urlparse
if uri_without_prefix.startswith("///"):
from urllib.parse import urlparse
from urllib.request import url2pathname

parsed_uri = urlparse(self._tracking_uri)
return url2pathname(parsed_uri.path)
# Legacy format: file:/path or file:./path - return as-is
return uri_without_prefix
return None

@property
Expand Down
40 changes: 40 additions & 0 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,43 @@ def test_set_tracking_uri(mlflow_mock):
mlflow_mock.set_tracking_uri.assert_not_called()
_ = logger.experiment
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")


@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_save_dir_file_uri_handling(mlflow_mock):
"""Test that save_dir correctly handles file URIs, especially on Windows."""
import platform

# Test proper Windows-style absolute file URI (the main fix)
logger_win = MLFlowLogger(tracking_uri="file:///C:/Dev/example/mlruns")
result_win = logger_win.save_dir
expected_win = "C:\\Dev\\example\\mlruns" if platform.system() == "Windows" else "/C:/Dev/example/mlruns"
assert result_win == expected_win

# Test proper Unix-style absolute file URI
logger_unix = MLFlowLogger(tracking_uri="file:///home/user/mlruns")
result_unix = logger_unix.save_dir
expected_unix = "\\home\\user\\mlruns" if platform.system() == "Windows" else "/home/user/mlruns"
assert result_unix == expected_unix

# Test proper file URI with special characters and spaces
logger_special = MLFlowLogger(tracking_uri="file:///path/with%20spaces/mlruns")
result_special = logger_special.save_dir
expected_special = "\\path\\with spaces\\mlruns" if platform.system() == "Windows" else "/path/with spaces/mlruns"
assert result_special == expected_special

# Test legacy format used by constructor (file:/path - should return as-is)
logger_legacy = MLFlowLogger(tracking_uri="file:/tmp/mlruns")
result_legacy = logger_legacy.save_dir
expected_legacy = "/tmp/mlruns"
assert result_legacy == expected_legacy

# Test legacy relative format
logger_rel = MLFlowLogger(tracking_uri="file:./mlruns")
result_rel = logger_rel.save_dir
expected_rel = "./mlruns"
assert result_rel == expected_rel

# Test non-file URI (should return None)
logger_http = MLFlowLogger(tracking_uri="http://localhost:8080")
assert logger_http.save_dir is None
Loading