Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add imagespec to hash to auto cache #2944

Draft
wants to merge 2 commits into
base: danielsola/se-256-add-private-module-crawling
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions flytekit/core/auto_cache.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,76 @@
from typing import Any, Callable, Protocol, runtime_checkable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable

from flytekit.image_spec.image_spec import ImageSpec


@dataclass
class VersionParameters:
    """
    Bundle of inputs handed to cache objects when a version hash is generated.

    Both fields are optional and default to ``None``; each cache object reads
    only the field(s) it needs.

    Args:
        func (Optional[Callable]): The function to generate a version for
        container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for
    """

    # Callable whose version is being computed, if any.
    func: Optional[Callable[..., Any]] = None
    # Container image given either as a plain string or as an ImageSpec, if any.
    container_image: Optional[Union[str, ImageSpec]] = None


@runtime_checkable
class AutoCache(Protocol):
    """
    A protocol that defines the interface for a caching mechanism
    that generates a version hash from a set of version parameters.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, AutoCache)``
    succeeds for any object exposing both members listed below.

    Attributes:
        salt (str): A string used to add uniqueness to the generated hash.

    Methods:
        get_version(params: VersionParameters) -> str:
            Given the version parameters, generates a version hash.
    """

    salt: str

    def get_version(self, params: VersionParameters) -> str:
        """
        Generate a version hash based on the provided parameters.

        Args:
            params (VersionParameters): Parameters to use for hash generation.

        Returns:
            str: The generated version hash.
        """
        ...


class CachePolicy:
    """
    A class that combines multiple caching mechanisms to generate a version hash.

    Args:
        *cache_objects: Variable number of AutoCache instances
        salt: Optional salt string to add uniqueness to the hash
    """

    def __init__(self, *cache_objects: AutoCache, salt: str = "") -> None:
        self.cache_objects = cache_objects
        self.salt = salt

    def get_version(self, params: VersionParameters) -> str:
        """
        Generate a version hash using all cache objects.

        Each cache object is asked for its version with the policy's salt
        applied; the individual versions are concatenated in the order the
        cache objects were given and hashed with SHA-256.

        The policy's salt is applied only for the duration of each call: the
        cache object's own ``salt`` is put back afterwards, so objects owned by
        the caller are not left modified.

        Args:
            params (VersionParameters): Parameters to use for hash generation.

        Returns:
            str: The combined hash from all cache objects.
        """
        import hashlib

        unset = object()  # marks a cache object that had no ``salt`` attribute
        component_hashes = []
        for cache_instance in self.cache_objects:
            previous_salt = getattr(cache_instance, "salt", unset)
            # Apply the policy's salt to the cache instance for this call only
            cache_instance.salt = self.salt
            try:
                component_hashes.append(cache_instance.get_version(params))
            finally:
                if previous_salt is unset:
                    del cache_instance.salt
                else:
                    cache_instance.salt = previous_salt

        # Generate SHA-256 hash of the concatenated component versions
        return hashlib.sha256("".join(component_hashes).encode()).hexdigest()
36 changes: 19 additions & 17 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

from flytekit.core.auto_cache import AutoCache
from flytekit.core.auto_cache import CachePolicy, VersionParameters
from flytekit.core.utils import str2bool

try:
Expand Down Expand Up @@ -100,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: Union[bool, list[AutoCache]] = ...,
cache: Union[bool, CachePolicy] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -136,9 +136,9 @@ def task(

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
task_config: Optional[T] = ...,
cache: Union[bool, list[AutoCache]] = ...,
cache: Union[bool, CachePolicy] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -169,13 +169,13 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ...


def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
_task_function: Optional[Callable[..., FuncOut]] = None,
task_config: Optional[T] = None,
cache: Union[bool, list[AutoCache]] = False,
cache: Union[bool, CachePolicy] = False,
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
Expand Down Expand Up @@ -213,8 +213,8 @@ def task(
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Callable[..., FuncOut],
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
]:
"""
Expand Down Expand Up @@ -343,17 +343,19 @@ def launch_dynamically():
:param accelerator: The accelerator to use for this task.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache):
cache_versions = [item.get_version() for item in cache]
task_hash = "".join(cache_versions)
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
if isinstance(cache, CachePolicy):
params = VersionParameters(func=fn, container_image=container_image)
cache_version_val = cache.get_version(params=params)
cache_val = True
else:
task_hash = ""
cache_val = cache
cache_version_val = cache_version

_metadata = TaskMetadata(
cache=cache,
cache=cache_val,
cache_serialize=cache_serialize,
cache_version=cache_version if not task_hash else task_hash,
cache_version=cache_version_val,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
interruptible=interruptible,
Expand Down Expand Up @@ -439,7 +441,7 @@ def wrapper(fn) -> ReferenceTask:
return wrapper


def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]:
def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorates the task with additional functionality if necessary.

Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,23 +843,23 @@ def workflow(

@overload
def workflow(
_workflow_function: Callable[P, FuncOut],
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ...


def workflow(
_workflow_function: Optional[Callable[P, FuncOut]] = None,
_workflow_function: Optional[Callable[..., FuncOut]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -894,7 +894,7 @@ def workflow(
the labels and annotations are allowed to be set as defaults.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import textwrap
from typing import Any, Callable

from flytekit.core.auto_cache import VersionParameters


class CacheFunctionBody:
"""
Expand All @@ -27,8 +29,10 @@ def __init__(self, salt: str = "salt") -> None:
"""
self.salt = salt

def get_version(self, func: Callable[..., Any]) -> str:
return self._get_version(func=func)
def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")
return self._get_version(func=params.func)

def _get_version(self, func: Callable[..., Any]) -> str:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import hashlib

from flytekit.core.auto_cache import VersionParameters
from flytekit.image_spec.image_spec import ImageSpec


class CacheImage:
    """
    Cache object that derives a version hash from a task's container image.

    Args:
        salt (str): A string combined with the image identifier before hashing.
    """

    def __init__(self, salt: str):
        self.salt = salt

    def get_version(self, params: VersionParameters) -> str:
        """
        Generate a version hash from the container image and the salt.

        Args:
            params (VersionParameters): Parameters to use for hash generation;
                only ``container_image`` is read.

        Returns:
            str: SHA-256 hex digest of the image identifier followed by the salt.

        Raises:
            ValueError: If ``params.container_image`` is None.
        """
        image = params.container_image
        if image is None:
            raise ValueError("Image-based cache requires a container_image parameter")

        # An ImageSpec contributes its tag; a plain string image is used as given.
        identifier = image.tag if isinstance(image, ImageSpec) else image
        salted = identifier + self.salt
        return hashlib.sha256(salted.encode("utf-8")).hexdigest()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pathlib import Path
from typing import Any, Callable, Set, Union

from flytekit.core.auto_cache import VersionParameters


@contextmanager
def temporarily_add_to_syspath(path):
Expand All @@ -24,9 +26,12 @@ def __init__(self, salt: str, root_dir: str):
self.salt = salt
self.root_dir = Path(root_dir).resolve()

def get_version(self, func: Callable[..., Any]) -> str:
hash_components = [self._get_version(func)]
dependencies = self._get_function_dependencies(func, set())
def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")

hash_components = [self._get_version(params.func)]
dependencies = self._get_function_dependencies(params.func, set())
for dep in dependencies:
hash_components.append(self._get_version(dep))
# Combine all component hashes into a single version hash
Expand Down
36 changes: 26 additions & 10 deletions plugins/flytekit-auto-cache/tests/test_function_body.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dummy_functions.dummy_function import dummy_function
from dummy_functions.dummy_function_comments_formatting_change import dummy_function as dummy_function_comments_formatting_change
from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change
from flytekit.core.auto_cache import VersionParameters
from flytekitplugins.auto_cache import CacheFunctionBody


Expand All @@ -11,9 +12,11 @@ def test_get_version_with_same_function_and_salt():
cache1 = CacheFunctionBody(salt="salt")
cache2 = CacheFunctionBody(salt="salt")

params = VersionParameters(func=dummy_function)

# Both calls should return the same hash since the function and salt are the same
version1 = cache1.get_version(dummy_function)
version2 = cache2.get_version(dummy_function)
version1 = cache1.get_version(params)
version2 = cache2.get_version(params)

assert version1 == version2, f"Expected {version1}, but got {version2}"

Expand All @@ -25,9 +28,11 @@ def test_get_version_with_different_salt():
cache1 = CacheFunctionBody(salt="salt1")
cache2 = CacheFunctionBody(salt="salt2")

params = VersionParameters(func=dummy_function)

# The hashes should be different because the salts are different
version1 = cache1.get_version(dummy_function)
version2 = cache2.get_version(dummy_function)
version1 = cache1.get_version(params)
version2 = cache2.get_version(params)

assert version1 != version2, f"Expected different hashes but got the same: {version1}"

Expand All @@ -38,8 +43,12 @@ def test_get_version_with_different_logic():
Test that functions with the same name but different logic produce different hashes.
"""
cache = CacheFunctionBody(salt="salt")
version1 = cache.get_version(dummy_function)
version2 = cache.get_version(dummy_function_logic_change)

params1 = VersionParameters(func=dummy_function)
params2 = VersionParameters(func=dummy_function_logic_change)

version1 = cache.get_version(params1)
version2 = cache.get_version(params2)

assert version1 != version2, (
f"Hashes should be different for functions with same name but different logic. "
Expand All @@ -61,8 +70,11 @@ def test_get_version_with_different_function_names():
"""
cache = CacheFunctionBody(salt="salt")

version1 = cache.get_version(function_one)
version2 = cache.get_version(function_two)
params1 = VersionParameters(func=function_one)
params2 = VersionParameters(func=function_two)

version1 = cache.get_version(params1)
version2 = cache.get_version(params2)

assert version1 != version2, (
f"Hashes should be different for functions with different names. "
Expand All @@ -76,8 +88,12 @@ def test_get_version_with_formatting_changes():
"""

cache = CacheFunctionBody(salt="salt")
version1 = cache.get_version(dummy_function)
version2 = cache.get_version(dummy_function_comments_formatting_change)

params1 = VersionParameters(func=dummy_function)
params2 = VersionParameters(func=dummy_function_comments_formatting_change)

version1 = cache.get_version(params1)
version2 = cache.get_version(params2)

assert version1 == version2, (
f"Hashes should be the same for functions with same name but different formatting. "
Expand Down
Loading
Loading