Merge pull request #4 from square/cmachak/align-1.5.0
cmachak/align-1.5.0
xmachak authored Jan 10, 2024
2 parents d9067e6 + 5994a2a commit 285c728
Showing 8 changed files with 787 additions and 585 deletions.
55 changes: 23 additions & 32 deletions cascade/executors/databricks/executor.py
@@ -1,30 +1,30 @@
from dataclasses import dataclass
import importlib
import os
import sys
import threading
import time
from types import ModuleType
from typing import Callable, Iterable, Optional

from cascade.executors.databricks.resource import DatabricksSecret
from cascade.executors.executor import Executor

try:
import cloudpickle
except ImportError:
import pickle as cloudpickle # Databricks renames cloudpickle to pickle in Runtimes 11 + # noqa: E501

import importlib
import os
import sys
import threading
import time
import s3fs
from dataclasses import dataclass
from slugify import slugify

from databricks_cli.cluster_policies.api import ClusterPolicyApi
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient
import s3fs
from slugify import slugify

from cascade.executors.databricks.resource import DatabricksSecret
from cascade.executors.databricks.job import DatabricksJob
from cascade.executors.databricks.resource import DatabricksResource
from cascade.prefect import get_prefect_logger
from cascade.utils import _base_module

if sys.version_info.major >= 3 and sys.version_info.minor >= 9:
from importlib.resources import files
@@ -146,7 +146,7 @@ def fs(self):
self._fs = s3fs.S3FileSystem(**self.resource.s3_credentials)
break
except KeyError:
logger.info(f"Waiting {wait} seconds to retry STS")
self.logger.info(f"Waiting {wait} seconds to retry STS")
n_retries += 1
time.sleep(wait)
wait *= 1.5
@@ -178,27 +178,17 @@ def cloudpickle_by_value(self) -> Iterable[ModuleType]:
Iterable[str]
Set of modules to pickle by value
"""
try:
modules_to_pickle = set()
for module in self.resource.cloud_pickle_by_value or []:
try:
modules_to_pickle.add(importlib.import_module(module))
except ModuleNotFoundError:
raise RuntimeError(
f"Unable to pickle {module} due to module not being "
"found in current Python environment."
)
except ImportError:
raise RuntimeError(
f"Unable to pickle {module} due to import error."
)
if self.resource.cloud_pickle_infer_base_module:
module = _base_module(self.func)
modules_to_pickle.add(module)
except AttributeError as e:
# This happens when the function is not part of a module,
# but cloudpickle will typically handle that
self.logger.warn(f"Failed to infer base module of function: {e}")
modules_to_pickle = set()
for module in self.resource.cloud_pickle_by_value or []:
try:
modules_to_pickle.add(importlib.import_module(module))
except ModuleNotFoundError:
raise RuntimeError(
f"Unable to pickle {module} due to module not being "
"found in current Python environment."
)
except ImportError:
raise RuntimeError(f"Unable to pickle {module} due to import error.")
return modules_to_pickle

@property
@@ -312,6 +302,7 @@ def _start(self):
client = self.runs_api
job = self.create_job()
databricks_payload = job.create_payload()
self.logger.info(f"Databricks job payload: {databricks_payload}")

self.active_job = client.submit_run(
databricks_payload, version=DATABRICKS_API_VERSION
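Note on the `cloudpickle_by_value` change above: the property now only imports the modules explicitly listed in `resource.cloud_pickle_by_value`; inference of the function's base module moves to `cascade/remote.py` (see below). A minimal sketch of how such a module set is typically consumed, assuming the standard `cloudpickle.register_pickle_by_value` API; the actual call site is not shown in these hunks:

```python
import importlib

import cloudpickle


def register_by_value(module_names):
    """Illustrative helper (not from the diff): import each named module and
    register it so cloudpickle serializes it by value, shipping the code with
    the pickled function instead of resolving it by name on the cluster."""
    for name in module_names:
        cloudpickle.register_pickle_by_value(importlib.import_module(name))


# e.g. before pickling the user function for upload:
# register_by_value(resource.cloud_pickle_by_value or [])
# payload = cloudpickle.dumps(func)
```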
9 changes: 3 additions & 6 deletions cascade/executors/databricks/job.py
@@ -3,7 +3,6 @@
from dataclasses import dataclass
from importlib.metadata import version
from typing import Optional
from uuid import uuid4

from cascade.executors.databricks.resource import (
DatabricksAutoscaleConfig,
@@ -26,9 +25,8 @@ class DatabricksJob:
storage_path: str
The full path to the directory for assets (input, output) on AWS (includes storage_key)
storage_key: str
A key suffixed to the storage location to ensure a unique path for each job
idempotency_token: str
A cache key for the job in Databricks; for checking the status of an active job
A key suffixed to the storage location to ensure a unique path for each job.
Also used as the `idempotency_token` in the Job API request.
cluster_policy_id: str
Generated by default by looking up using team name
existing_cluster_id: str
@@ -47,7 +45,6 @@ class DatabricksJob:
storage_key: str
run_path: str
cluster_policy_id: str
idempotency_token: str = uuid4().hex
existing_cluster_id: Optional[str] = None
timeout_seconds: int = 86400

@@ -61,7 +58,7 @@ def create_payload(self):
"tasks": [task],
"run_name": self.name,
"timeout_seconds": self.timeout_seconds,
"idempotency_token": self.idempotency_token,
"idempotency_token": self.storage_key,
"access_control_list": [
{
"group_name": self.resource.group_name,
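A note on the removed `idempotency_token: str = uuid4().hex` field: as a dataclass default it is evaluated once, when the class body is defined, so every `DatabricksJob` created in the same process would share one token. Reusing the per-job `storage_key` sidesteps that; this is one plausible motivation, the diff itself does not state the reason. A short runnable sketch of the pitfall (illustrative class, not from the repository):

```python
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class Job:
    # uuid4() runs once, when the class body is evaluated, not per instance.
    token: str = uuid4().hex


a, b = Job(), Job()
assert a.token == b.token  # both instances share the same "unique" token
```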
35 changes: 19 additions & 16 deletions cascade/executors/databricks/run.py
@@ -1,8 +1,9 @@
import logging
import os
import pickle
import sys

import boto3

try:
import cloudpickle
except ImportError:
@@ -16,34 +17,36 @@


def run():
bucket_location, storage_key = sys.argv[1], sys.argv[2]
mountpoint = f"/mnt/cascade/{storage_key}/"
dbfs_mountpoint = "/dbfs" + mountpoint

dbutils.fs.mount( # dbutils is populated in cluster namespace without import # noqa: E501, F821
bucket_location.replace("s3://", "s3n://"),
mountpoint,
)
try:
with open(os.path.join(dbfs_mountpoint, INPUT_FILENAME), "rb") as f:
func = cloudpickle.load(f)
bucket_location, _ = sys.argv[1], sys.argv[2]
s3_bucket, object_path = bucket_location.replace("s3://", "").split("/", 1)

try:
s3 = boto3.resource("s3")
func = cloudpickle.loads(
s3.Bucket(s3_bucket)
.Object(f"{object_path}/{INPUT_FILENAME}")
.get()["Body"]
.read()
)
logger.info("Starting execution")

result = func()

logger.info(f"Saving output of task to {bucket_location}/{OUTPUT_FILENAME}")
try:
with open(os.path.join(dbfs_mountpoint, OUTPUT_FILENAME), "wb") as f:
pickle.dump(result, f)
s3.Bucket(s3_bucket).Object(f"{object_path}/{OUTPUT_FILENAME}").put(
Body=pickle.dumps(result)
)
except RuntimeError as e:
logger.error(
"Failed to serialize user function return value. Be sure not to return "
"Spark objects from user functions. For example, you should convert "
"Spark dataframes to Pandas dataframes before returning."
)
raise e
finally:
dbutils.fs.unmount(mountpoint) # noqa: F821
except RuntimeError as e:
logger.error("Failed to execute user function")
raise e


if __name__ == "__main__":
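The rewritten `run()` above drops the `dbutils.fs.mount` / DBFS path and reads and writes the pickled payloads directly with boto3. A condensed sketch of that round trip; the object key names are placeholders, since the `INPUT_FILENAME` / `OUTPUT_FILENAME` constants are not shown in this hunk:

```python
import pickle

import boto3
import cloudpickle


def run_from_s3(bucket: str, prefix: str) -> None:
    """Download the pickled callable, execute it, upload the pickled result."""
    s3 = boto3.resource("s3")

    # e.g. s3://<bucket>/<prefix>/input.pkl (key names are placeholders)
    body = s3.Bucket(bucket).Object(f"{prefix}/input.pkl").get()["Body"].read()
    func = cloudpickle.loads(body)

    result = func()

    s3.Bucket(bucket).Object(f"{prefix}/output.pkl").put(Body=pickle.dumps(result))
```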
2 changes: 1 addition & 1 deletion cascade/executors/executor.py
@@ -113,7 +113,7 @@ def _result(self):
try:
with self.fs.open(self.output_filepath, "rb") as f:
result = cloudpickle.load(f)
self.fs.rm(self.storage_location, recursive=True)
self.fs.rm(self.storage_path, recursive=True)
except FileNotFoundError:
raise FileNotFoundError(
f"Could not find output file {self.output_filepath}"
15 changes: 14 additions & 1 deletion cascade/remote.py
@@ -18,7 +18,7 @@
get_prefect_logger,
is_prefect_cloud_deployment,
)
from cascade.utils import wrapped_partial
from cascade.utils import _infer_base_module, wrapped_partial

RESERVED_ARG_PREFIX = "remote_"

@@ -217,6 +217,19 @@ def remote_func(*args, **kwargs):
or "DatabricksResource" in type(resource).__name__
):
prefect_logger.info("Executing task with DatabricksResource.")
failed_to_infer_base = (
"Unable to infer base module of function. Specify "
"the base module in the `cloud_pickle_by_value` attribute "
"of the DatabricksResource object if necessary."
)
if resource.cloud_pickle_infer_base_module:
base_module_name = _infer_base_module(func)
# if base module is __main__ or None, it can't be registered
if base_module_name is None or base_module_name.startswith("__"):
prefect_logger.warn(failed_to_infer_base)
else:
resource.cloud_pickle_by_value.append(base_module_name)

executor = DatabricksExecutor(
func=packed_func,
resource=resource,
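The block added to `remote_func` can be read as: if base-module inference is enabled and the function lives in an importable package, register that package for by-value pickling; otherwise log the warning and let cloudpickle fall back to its defaults (the `None` check runs before `startswith`, so a failed inference does not raise). A self-contained sketch of the same guard; the helper name is hypothetical:

```python
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)


def maybe_register_base_module(base_module_name: Optional[str],
                               pickle_by_value: List[str]) -> List[str]:
    """Illustrative helper mirroring the guard added in remote.py."""
    if base_module_name is None or base_module_name.startswith("__"):
        # Nothing importable to register; cloudpickle still serializes
        # __main__ functions by value on its own.
        logger.warning(
            "Unable to infer base module of function. Specify the base module "
            "in the `cloud_pickle_by_value` attribute of the DatabricksResource "
            "object if necessary."
        )
    else:
        pickle_by_value.append(base_module_name)
    return pickle_by_value


# maybe_register_base_module("my_pkg", [])    -> ["my_pkg"]
# maybe_register_base_module("__main__", [])  -> []  (warning logged)
# maybe_register_base_module(None, [])        -> []  (warning logged)
```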
21 changes: 13 additions & 8 deletions cascade/utils.py
@@ -1,12 +1,9 @@
import ast
from functools import partial
import importlib
import inspect
from inspect import signature
import itertools
import logging
import pkgutil
import sys
from typing import List

import prefect
@@ -51,12 +48,20 @@ def _get_object_args(obj: object):
return list(signature(obj).parameters.keys())


def _base_module(func):
def _infer_base_module(func):
"""
Inspects the function to find the base module in which it was defined.
Args:
func (Callable): a function
Returns:
str: the name of the base module in which the function was defined
"""
func_module = inspect.getmodule(func)
base_name, *_ = func_module.__name__.partition(".")
base = sys.modules[base_name]
for _, name, _ in pkgutil.walk_packages(base.__path__):
return importlib.import_module(f"{base_name}.{name}")
try:
base_name, *_ = func_module.__name__.partition(".")
return base_name
except AttributeError:
return None


def wrapped_partial(func, *args, **kwargs):
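For concreteness, the behaviour of the new `_infer_base_module` on a couple of ordinary callables; `base_module_of` below is a local copy of the same logic so the sketch runs on its own, it is not a repository function:

```python
import inspect
import json  # any importable stdlib module works for the demonstration


def base_module_of(func):
    """Same logic as _infer_base_module, duplicated here for a runnable example."""
    module = inspect.getmodule(func)
    try:
        base_name, *_ = module.__name__.partition(".")
        return base_name
    except AttributeError:
        return None


print(base_module_of(json.dumps))   # "json"      (a function from an importable module)
print(base_module_of(lambda: 42))   # "__main__"  (defined in the running script)
```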