Skip to content

Commit

Permalink
Make spark_udf use NFS to broadcast model to spark executor on Databricks runtime and Spark Connect mode (mlflow#10463)
Browse files Browse the repository at this point in the history

Signed-off-by: Weichen Xu <[email protected]>
Signed-off-by: mlflow-automation <[email protected]>
Co-authored-by: mlflow-automation <[email protected]>
  • Loading branch information
WeichenXu123 and mlflow-automation authored Nov 21, 2023
1 parent 0f57a1e commit 800b8f3
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,10 +1239,6 @@ def spark_udf(
from mlflow.utils._spark_utils import _SparkDirectoryDistributor

is_spark_connect = _is_spark_connect()
if is_spark_connect and env_manager in (_EnvManager.VIRTUALENV, _EnvManager.CONDA):
raise MlflowException.invalid_parameter_value(
f"Environment manager {env_manager!r} is not supported in Spark connect mode.",
)
# Used in test to force install local version of mlflow when starting a model server
mlflow_home = os.environ.get("MLFLOW_HOME")
openai_env_vars = mlflow.openai._OpenAIEnvVar.read_environ()
Expand All @@ -1260,6 +1256,22 @@ def spark_udf(
is_spark_in_local_mode or should_use_nfs or is_spark_connect
)

# For spark connect mode,
# If client code is executed in databricks runtime and NFS is available,
# we save model to NFS temp directory in the driver
# and load the model in the executor.
should_spark_connect_use_nfs = is_in_databricks_runtime() and should_use_nfs

if (
is_spark_connect
and env_manager in (_EnvManager.VIRTUALENV, _EnvManager.CONDA)
and not should_spark_connect_use_nfs
):
raise MlflowException.invalid_parameter_value(
f"Environment manager {env_manager!r} is not supported in Spark connect mode "
"when either non-Databricks environment is in use or NFS is unavailable.",
)

local_model_path = _download_artifact_from_uri(
artifact_uri=model_uri,
output_path=_create_model_downloading_tmp_dir(should_use_nfs),
Expand Down Expand Up @@ -1580,7 +1592,7 @@ def batch_predict_fn(pdf, params=None):
return client.invoke(pdf).get_predictions()

elif env_manager == _EnvManager.LOCAL:
if is_spark_connect:
if is_spark_connect and not should_spark_connect_use_nfs:
model_path = os.path.join(
tempfile.gettempdir(),
"mlflow",
Expand Down

0 comments on commit 800b8f3

Please sign in to comment.