diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index be846337cd..866dd7f4ec 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -14,6 +14,7 @@ from flyteidl.core.literals_pb2 import Literal from flyteidl.core.types_pb2 import LiteralType +from flytekit.loggers import logger from flytekit.models.literals import Literal from flytekit.models.types import LiteralType @@ -24,6 +25,7 @@ # probably worthwhile to add a format field to the date as well # but separating may be hard as it'll need a new element in the URI mapping. TIME_PARTITION = "ds" +TIME_PARTITION_KWARG = "time_partition" class InputsBase(object): @@ -56,6 +58,7 @@ def __init__( self.partitions = partitions self.time_partition = time_partition + # todo: add time partition arg hint def __call__(self, *args, **kwargs): return self.bind_partitions(*args, **kwargs) @@ -64,6 +67,20 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: if len(args) > 0: raise ValueError("Cannot set partition values by position") + if TIME_PARTITION_KWARG in kwargs: + if not self.artifact.time_partitioned: + raise ValueError("Cannot bind time partition to non-time partitioned artifact") + p = kwargs[TIME_PARTITION_KWARG] + if isinstance(p, datetime.datetime): + self.time_partition = TimePartition(value=art_id.LabelValue(static_value=f"{p}")) + elif isinstance(p, art_id.InputBindingData): + self.time_partition = TimePartition(value=art_id.LabelValue(input_binding=p)) + else: + raise ValueError(f"Time partition needs to be input binding data or static string, not {p}") + # Given the context, shouldn't need to set further reference_artifacts. + + del kwargs[TIME_PARTITION_KWARG] + if len(kwargs) > 0: p = Partitions(None) # k is the partition key, v should be static, or an input to the task or workflow @@ -71,9 +88,9 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: if k not in self.artifact.partition_keys: raise ValueError(f"Partition key {k} not found in {self.artifact.partition_keys}") if isinstance(v, art_id.InputBindingData): - p.partitions[k] = Partition(art_id.LabelValue(input_binding=v)) + p.partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k) elif isinstance(v, str): - p.partitions[k] = Partition(art_id.LabelValue(static_value=v)) + p.partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k) else: raise ValueError(f"Partition key {k} needs to be input binding data or static string, not {v}") # Given the context, shouldn't need to set further reference_artifacts. @@ -81,29 +98,12 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: return self - def bind_time_partition(self, p: Union[datetime.datetime, art_id.InputBindingData]) -> ArtifactIDSpecification: - # See the parallel function in the main Artifact class for more information. - if not self.artifact.time_partitioned: - raise ValueError("Cannot bind time partition to non-time partitioned artifact") - if isinstance(p, datetime.datetime): - self.time_partition = TimePartition(value=art_id.LabelValue(static_value=f"{p}")) - else: - self.time_partition = TimePartition(value=art_id.LabelValue(input_binding=p)) - # Given the context, shouldn't need to set further reference_artifacts. - - return self - def to_partial_artifact_id(self) -> art_id.ArtifactID: # This function should only be called by transform_variable_map artifact_id = self.artifact.to_flyte_idl().artifact_id # Use the partitions from this object, but replacement is not allowed by protobuf, so generate new object - if self.partitions: - p = self.partitions.to_flyte_idl(self.time_partition) - elif self.time_partition: - # Use an empty partitions object to avoid re-implementing to_flyte_idl - p = Partitions(None).to_flyte_idl(self.time_partition) - else: - p = None + p = partitions_to_idl(self.partitions, self.time_partition) + if self.artifact.partition_keys: required = len(self.artifact.partition_keys) required += 1 if self.artifact.time_partitioned else 0 @@ -203,17 +203,11 @@ def to_flyte_idl( ) return aq - temp_partitions = self.partitions or Partitions(None) - # temp_partitions.idl(query_bindings=self.bindings, fulfilled=bindings, input_keys=input_keys) - # there's the list of what this query needs, what it has to fulfill it, and other inputs that can be used. - # logic is. - # - if you need something that isn't bound, then fail. - # - - partition_idl = temp_partitions.to_flyte_idl(self.time_partition, bindings) + p = partitions_to_idl(self.partitions, self.time_partition, bindings) i = art_id.ArtifactID( artifact_key=ak, - partitions=partition_idl, + partitions=p, ) aq = art_id.ArtifactQuery( @@ -233,7 +227,7 @@ def as_uri(self) -> str: class TimePartition(object): def __init__( self, - value: Union[art_id.LabelValue, art_id.InputBindingData, str, datetime.datetime], + value: Union[art_id.LabelValue, art_id.InputBindingData, str, datetime.datetime, None], op: Optional[str] = None, other: Optional[timedelta] = None, ): @@ -243,11 +237,11 @@ def __init__( value = art_id.LabelValue(static_value=f"{value}") elif isinstance(value, art_id.InputBindingData): value = art_id.LabelValue(input_binding=value) - # else should already be a LabelValue + # else should already be a LabelValue or None self.value = value self.op = op self.other = other - self.reference_artifact = None + self.reference_artifact: Optional[Artifact] = None def __add__(self, other: timedelta) -> TimePartition: tp = TimePartition(self.value, op="+", other=other) @@ -263,9 +257,73 @@ def truncate_to_day(self): # raise NotImplementedError("Not implemented yet") return self + def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art_id.Partitions: + if not self.reference_artifact or (self.reference_artifact and self.reference_artifact not in bindings): + # basically if there's no reference artifact, or if the reference artifact isn't + # in the list of triggers, then treat it like normal. + return art_id.Partitions(value={TIME_PARTITION: self.value}) + elif self.reference_artifact in bindings: + idx = bindings.index(self.reference_artifact) + transform = None + if self.op and self.other and isinstance(self.other, timedelta): + transform = str(self.op) + isodate.duration_isoformat(self.other) + lv = art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData( + index=idx, + partition_key=TIME_PARTITION, + transform=transform, + ) + ) + return art_id.Partitions(value={TIME_PARTITION: lv}) + # investigate if this happens, if not, remove. else + logger.warning(f"Investigate - time partition in trigger with unhandled reference artifact {self}") + return art_id.Partitions(value={TIME_PARTITION: self.value}) + + def to_flyte_idl(self, bindings: Optional[typing.List[Artifact]] = None) -> Optional[art_id.Partitions]: + if bindings and len(bindings) > 0: + return self.get_idl_partitions_for_trigger(bindings) + + if not self.value: + # This is only for triggers - the backend needs to know of the existence of a time partition + return art_id.Partitions(value={TIME_PARTITION: art_id.LabelValue(static_value="")}) + + return art_id.Partitions(value={TIME_PARTITION: self.value}) + + +def merge_idl_partitions( + p_idl: Optional[art_id.Partitions], time_p_idl: Optional[art_id.Partitions] +) -> Optional[art_id.Partitions]: + if not p_idl and not time_p_idl: + return None + p = {} + if p_idl and p_idl.value: + p.update(p_idl.value) + if time_p_idl and time_p_idl.value: + p.update(time_p_idl.value) + + return art_id.Partitions(value=p) if p else None + + +def partitions_to_idl( + partitions: Optional[Partitions], + time_partition: Optional[TimePartition], + bindings: Optional[typing.List[Artifact]] = None, +) -> Optional[art_id.Partitions]: + partition_idl = None + if partitions: + partition_idl = partitions.to_flyte_idl(bindings) + + time_p_idl = None + if time_partition: + time_p_idl = time_partition.to_flyte_idl(bindings) + + merged = merge_idl_partitions(partition_idl, time_p_idl) + return merged + class Partition(object): - def __init__(self, value: art_id.LabelValue): + def __init__(self, value: Optional[art_id.LabelValue], name: str = None): + self.name = name self.value = value self.reference_artifact: Optional[Artifact] = None @@ -278,9 +336,9 @@ def __init__(self, partitions: Optional[typing.Dict[str, Union[str, art_id.Input if isinstance(v, Partition): self._partitions[k] = v elif isinstance(v, art_id.InputBindingData): - self._partitions[k] = Partition(art_id.LabelValue(input_binding=v)) + self._partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k) else: - self._partitions[k] = Partition(art_id.LabelValue(static_value=v)) + self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k) self.reference_artifact = None @property @@ -298,61 +356,68 @@ def __getattr__(self, item): return self.partitions[item] raise AttributeError(f"Partition {item} not found in {self}") + def get_idl_partitions_for_trigger( + self, + bindings: typing.List[Artifact] = None, + ) -> art_id.Partitions: + p = {} + # First create partition requirements for all the partitions + if self.reference_artifact and self.reference_artifact in bindings: + idx = bindings.index(self.reference_artifact) + triggering_artifact = bindings[idx] + if triggering_artifact.partition_keys: + for k in triggering_artifact.partition_keys: + p[k] = art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData( + index=idx, + partition_key=k, + ) + ) + + for k, v in self.partitions.items(): + if not v.reference_artifact or ( + v.reference_artifact + and v.reference_artifact is self.reference_artifact + and not v.reference_artifact in bindings + ): + # consider changing condition to just check for static value + p[k] = art_id.LabelValue(static_value=v.value.static_value) + elif v.reference_artifact in bindings: + # This line here is why the PartitionValue object has a name field. + # We might bind to a partition key that's a different name than the k here. + p[k] = art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData( + index=bindings.index(v.reference_artifact), + partition_key=v.name, + ) + ) + else: + raise ValueError(f"Partition has unhandled reference artifact {v.reference_artifact}") + + return art_id.Partitions(value=p) + def to_flyte_idl( self, - time_partition: Optional[TimePartition], bindings: Optional[typing.List[Artifact]] = None, ) -> Optional[art_id.Partitions]: - if not self.partitions and not time_partition: - return None # This is basically a flag, which indicates that we are serializing this object within the context of a Trigger + # If we are not, then we are just serializing normally if bindings and len(bindings) > 0: - p = {} - if self.partitions: - for k, v in self.partitions.items(): - if not v.reference_artifact or ( - v.reference_artifact - and v.reference_artifact is self.reference_artifact - and not v.reference_artifact in bindings - ): - p[k] = art_id.LabelValue(static_value=v.value) - elif v.reference_artifact in bindings: - p[k] = art_id.LabelValue( - triggered_binding=art_id.ArtifactBindingData( - index=bindings.index(v.reference_artifact), - partition_key=v.name, - ) - ) - else: - raise ValueError(f"Partition has unhandled reference artifact {v.reference_artifact}") - - if time_partition: - if not time_partition.reference_artifact or ( - time_partition.reference_artifact is self.reference_artifact - and not time_partition.reference_artifact in bindings - ): - p[TIME_PARTITION] = art_id.LabelValue(static_value=time_partition.expr) - elif time_partition.reference_artifact in bindings: - transform = None - if time_partition.op and time_partition.other and isinstance(time_partition.other, timedelta): - transform = str(time_partition.op) + isodate.duration_isoformat(time_partition.other) - p[TIME_PARTITION] = art_id.LabelValue( - triggered_binding=art_id.ArtifactBindingData( - index=bindings.index(time_partition.reference_artifact), - partition_key=TIME_PARTITION, - transform=transform, - ) - ) - else: - raise ValueError(f"Partition values has unhandled reference artifact {time_partition}") - return art_id.Partitions(value=p) + return self.get_idl_partitions_for_trigger(bindings) + + if not self.partitions: + return None - # If we are not, then we are just serializing normally pp = {} if self.partitions: for k, v in self.partitions.items(): - pp[k] = v.value - pp.update({TIME_PARTITION: time_partition.value} if time_partition else {}) + if v.value is None: + # This should only happen when serializing for triggers + # Probably indicative of something in the data model that can be fixed + # down the road. + pp[k] = art_id.LabelValue(static_value="") + else: + pp[k] = v.value return art_id.Partitions(value=pp) @@ -439,6 +504,12 @@ def __init__( self.partition_keys = list(partitions.keys()) else: self._partitions = partitions + self.partition_keys = list(partitions.partitions.keys()) + self._partitions.set_reference_artifact(self) + if not partitions and partition_keys: + # this should be the only time where we create Partition objects with None + p = {k: Partition(None, name=k) for k in partition_keys} + self._partitions = Partitions(p) self._partitions.set_reference_artifact(self) self.python_val = python_val self.python_type = python_type @@ -458,22 +529,17 @@ def __call__(self, *args, **kwargs) -> ArtifactIDSpecification: partial_id = ArtifactIDSpecification(self) return partial_id.bind_partitions(*args, **kwargs) - def bind_time_partition(self, p: Union[datetime.datetime, art_id.InputBindingData]) -> ArtifactIDSpecification: - """ - This function should only ever be called in the context of a task or workflow's output, to be - used in an Annotated[] call. The other styles will go through different call functions. - """ - # Can't guarantee the order in which time/non-time partitions are bound so create the helper - # object and invoke the function there. - partial_id = ArtifactIDSpecification(self) - return partial_id.bind_time_partition(p) - @property def partitions(self) -> Optional[Partitions]: return self._partitions @property - def time_partition(self) -> Optional[TimePartition]: + def time_partition(self) -> TimePartition: + if not self.time_partitioned: + raise ValueError(f"Artifact {self.name} is not time partitioned") + if not self._time_partition and self.time_partitioned: + self._time_partition = TimePartition(None) + self._time_partition.reference_artifact = self return self._time_partition def __str__(self): @@ -492,7 +558,7 @@ def __repr__(self): def get( cls, uri: Optional[str], - artifact_id: Optional[idl.ArtifactID], + artifact_id: Optional[art_id.ArtifactID], remote: FlyteRemote, get_details: bool = False, ) -> Optional[Artifact]: @@ -532,16 +598,22 @@ def query( partitions = Partitions(partitions) partitions.reference_artifact = self # only set top level - time_partition = ( - TimePartition(time_partition) if time_partition and not isinstance(time_partition, TimePartition) else None - ) + tp = None + if time_partition: + if isinstance(time_partition, TimePartition): + tp = time_partition + else: + tp = TimePartition(time_partition) + tp.reference_artifact = self + + tp = tp or (self.time_partition if self.time_partitioned else None) aq = ArtifactQuery( artifact=self, name=self.name, project=project or self.project or None, domain=domain or self.domain or None, - time_partition=time_partition or self.time_partition, + time_partition=tp, partitions=partitions or self.partitions, tag=tag or self.tags[0] if self.tags else None, ) @@ -596,7 +668,7 @@ def as_artifact_id(self) -> art_id.ArtifactID: return self.to_flyte_idl().artifact_id def embed_as_query( - self, bindings: typing.List[Artifact], partition: Optional[str], expr: Optional[str] + self, bindings: typing.List[Artifact], partition: Optional[str] = None, expr: Optional[str] = None ) -> art_id.ArtifactQuery: """ This should only be called in the context of a Trigger @@ -608,9 +680,7 @@ def embed_as_query( idx = bindings.index(self) aq = art_id.ArtifactQuery( binding=art_id.ArtifactBindingData( - index=idx, - partition_key=partition, - transform=str(expr) if expr and partition else None, + index=idx, partition_key=partition, transform=str(expr) if expr and partition else None ) ) return aq @@ -621,12 +691,8 @@ def to_flyte_idl(self) -> artifacts_pb2.Artifact: This is here instead of translator because it's in the interface, a relatively simple proto object that's exposed to the user. """ - if self.partitions: - p = self.partitions.to_flyte_idl(self.time_partition) - elif self.time_partition: - p = Partitions(None).to_flyte_idl(self.time_partition) - else: - p = None + p = partitions_to_idl(self.partitions, self.time_partition if self.time_partitioned else None) + return artifacts_pb2.Artifact( artifact_id=art_id.ArtifactID( artifact_key=art_id.ArtifactKey( @@ -651,7 +717,7 @@ def as_create_request(self) -> artifacts_pb2.CreateArtifactRequest: value=self.literal, type=self.literal_type, ) - partitions = self.partitions.to_flyte_idl(self.time_partition) + partitions = partitions_to_idl(self.partitions, self.time_partition) tag = self.tags[0] if self.tags else None return artifacts_pb2.CreateArtifactRequest(artifact_key=ak, spec=spec, partitions=partitions, tag=tag) @@ -681,7 +747,7 @@ def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact: a._partitions = Partitions( partitions={ - k: Partition(value=v) + k: Partition(value=v, name=k) for k, v in pb2.artifact_id.partitions.value.items() if k != TIME_PARTITION } diff --git a/flytekit/trigger.py b/flytekit/trigger.py index a2eced3494..6623565d84 100644 --- a/flytekit/trigger.py +++ b/flytekit/trigger.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Type, Union import isodate +from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import identifier_pb2 as idl from flyteidl.core import interface_pb2 @@ -20,9 +21,9 @@ class Trigger(TrackedInstance): trigger_on=[dailyArtifact, hourlyArtifact], inputs={ "today_upstream": dailyArtifact, # this means: use the matched artifact - "yesterday_upstream": dailyArtifact.time_partition - timedelta(days=1), - # this means, query for yesterday's artifact because there's math. - "other_daily_upstream": hourlyArtifact, # this means: use the matched hourly artifact + "yesterday_upstream": dailyArtifact.query(time_partition=dailyArtifact.time_partition - timedelta(days=1)), + # this means: use the matched hourly artifact + "other_daily_upstream": hourlyArtifact.query(partitions={"region": "LAX"}), "region": "SEA", # static value that will be passed as input "other_artifact": UnrelatedArtifact.query(time_partition=dailyArtifact.time_partition - timedelta(days=1)), "other_artifact_2": UnrelatedArtifact.query(time_partition=hourlyArtifact.time_partition.truncate_to_day()), @@ -68,7 +69,7 @@ def get_parameter_map( for k, v in self.inputs.items(): var = input_typed_interface[k].to_flyte_idl() if isinstance(v, Artifact): - aq = v.embed_as_query(self.triggers, None, None) + aq = v.embed_as_query(self.triggers) p = interface_pb2.Parameter(var=var, artifact_query=aq) elif isinstance(v, ArtifactQuery): p = interface_pb2.Parameter(var=var, artifact_query=v.to_flyte_idl(self.triggers)) @@ -92,20 +93,19 @@ def get_parameter_map( pm[k] = p return interface_pb2.ParameterMap(parameters=pm) - def to_flyte_idl(self) -> idl.Trigger: + def to_flyte_idl(self) -> art_id.Trigger: try: name = f"{self.instantiated_in}.{self.lhs}" - except Exception: + except Exception: # noqa broad for now given the changing nature of the tracker implementation. import random from uuid import UUID name = "trigger" + UUID(int=random.getrandbits(128)).hex # project/domain will be empty - to be bound later at registration time. - artifact_ids = [a.to_flyte_idl().artifact_id for a in self.triggers] - return idl.Trigger( + return art_id.Trigger( trigger_id=idl.Identifier( resource_type=idl.ResourceType.LAUNCH_PLAN, name=name, @@ -139,4 +139,4 @@ def __call__(self, *args, **kwargs): self_idl = self.to_flyte_idl() trigger_lp._additional_metadata = self_idl - return trigger_lp + return entity diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index 3171dfbdfb..c83e5cb035 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -4,17 +4,14 @@ import pandas as pd import pytest -from flyteidl.artifact import artifacts_pb2 -from flyteidl.core import identifier_pb2 from typing_extensions import Annotated -from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.artifact import Artifact, Inputs from flytekit.core.context_manager import FlyteContextManager from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.remote.remote import FlyteRemote from flytekit.tools.translator import get_serializable from flytekit.types.structured.structured_dataset import StructuredDataset @@ -39,7 +36,7 @@ def test_basic_option_a_rev(): @task def t1( b_value: str, dt: datetime.datetime - ) -> Annotated[pd.DataFrame, a1_t_ab.bind_time_partition(Inputs.dt)(b=Inputs.b_value, a="manual")]: + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dt, b=Inputs.b_value, a="manual")]: df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) return df @@ -64,7 +61,7 @@ def test_basic_option_a(): @task def t1( b_value: str, dt: datetime.datetime - ) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value, a="manual").bind_time_partition(Inputs.dt)]: + ) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value, a="manual", time_partition=Inputs.dt)]: df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) return df @@ -75,6 +72,8 @@ def t1( assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data" assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + +def test_basic_option_a2(): a2_ab = Artifact(name="my_data2", partition_keys=["a", "b"]) with pytest.raises(ValueError): @@ -94,6 +93,8 @@ def t2(b_value: str) -> Annotated[pd.DataFrame, a2_ab(a=Inputs.b_value, b="manua assert t2_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data2" assert t2_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + +def test_basic_option_a3(): a3 = Artifact(name="my_data3") @task @@ -129,30 +130,6 @@ def wf() -> Annotated[CustomReturn, wf_alias]: assert tag.value.static_value == "my_v0.1.0" -def test_artifact_as_promise_query(): - # when artifact is partially specified, can be used as a query input - wf_artifact = Artifact(project="project1", domain="dev", name="wf_artifact", tags=["my_v0.1.0"]) - - @task - def t1(a: CustomReturn) -> CustomReturn: - print(a) - return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) - - @workflow - def wf(a: CustomReturn = wf_artifact.query()): - u = t1(a=a) - return u - - ctx = FlyteContextManager.current_context() - lp = LaunchPlan.get_default_launch_plan(ctx, wf) - entities = OrderedDict() - spec = get_serializable(entities, serialization_settings, lp) - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.project == "project1" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.domain == "dev" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.name == "wf_artifact" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.value.static_value == "my_v0.1.0" - - def test_query_basic(): aa = Artifact( name="ride_count_data", @@ -187,6 +164,30 @@ def test_not_specified_behavior(): assert aq.artifact_id.HasField("partitions") is False +def test_artifact_as_promise_query(): + # when artifact is partially specified, can be used as a query input + wf_artifact = Artifact(project="project1", domain="dev", name="wf_artifact", tags=["my_v0.1.0"]) + + @task + def t1(a: CustomReturn) -> CustomReturn: + print(a) + return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) + + @workflow + def wf(a: CustomReturn = wf_artifact.query()): + u = t1(a=a) + return u + + ctx = FlyteContextManager.current_context() + lp = LaunchPlan.get_default_launch_plan(ctx, wf) + entities = OrderedDict() + spec = get_serializable(entities, serialization_settings, lp) + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.project == "project1" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.domain == "dev" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.name == "wf_artifact" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.value.static_value == "my_v0.1.0" + + def test_artifact_as_promise(): # when the full artifact is specified, the artifact should be bindable as a literal wf_artifact = Artifact(project="pro", domain="dom", name="key", version="v0.1.0", partitions={"region": "LAX"}) @@ -197,14 +198,15 @@ def t1(a: CustomReturn) -> CustomReturn: return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) @workflow - def wf(a: CustomReturn = wf_artifact): + def wf2(a: CustomReturn = wf_artifact): u = t1(a=a) return u ctx = FlyteContextManager.current_context() - lp = LaunchPlan.get_default_launch_plan(ctx, wf) + lp = LaunchPlan.get_default_launch_plan(ctx, wf2) entities = OrderedDict() spec = get_serializable(entities, serialization_settings, lp) + x = spec.spec.default_inputs.parameters["a"] assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.project == "pro" assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.domain == "dom" assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.name == "key" @@ -212,94 +214,3 @@ def wf(a: CustomReturn = wf_artifact): aq = wf_artifact.query().to_flyte_idl() assert aq.artifact_id.HasField("partitions") is True assert aq.artifact_id.partitions.value["region"].static_value == "LAX" - - -@pytest.mark.sandbox_test -def test_create_an_artifact33_locally(): - import grpc - from flyteidl.artifact.artifacts_pb2_grpc import ArtifactRegistryStub - - local_artifact_channel = grpc.insecure_channel("127.0.0.1:50051") - stub = ArtifactRegistryStub(local_artifact_channel) - ak = identifier_pb2.ArtifactKey(project="flytesnacks", domain="development", name="f3bea14ee52f8409eb5b/n0/0/o/o0") - ai = identifier_pb2.ArtifactID(artifact_key=ak) - req = artifacts_pb2.GetArtifactRequest(query=identifier_pb2.ArtifactQuery(artifact_id=ai)) - x = stub.GetArtifact(req) - print(x) - - -@pytest.mark.sandbox_test -def test_create_an_artifact_locally(): - df = pd.DataFrame({"Name": ["Mary", "Jane"], "Age": [22, 23]}) - # a = Artifact.initialize(python_val=df, python_type=pd.DataFrame, name="flyteorg.test.yt.test1", - # aliases=["v0.1.0"]) - a = Artifact.initialize(python_val=42, python_type=int, name="flyteorg.test.yt.test1", aliases=["v0.1.6"]) - r = FlyteRemote( - Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), - default_project="flytesnacks", - default_domain="development", - ) - r.create_artifact(a) - print(a) - - -@pytest.mark.sandbox_test -def test_pull_artifact_and_use_to_launch(): - """ - df_artifact = Artifact("flyte://a1") - remote.execute(wf, inputs={"a": df_artifact}) - Artifact.i - """ - r = FlyteRemote( - Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), - default_project="flytesnacks", - default_domain="development", - ) - wf = r.fetch_workflow( - "flytesnacks", "development", "cookbook.core.flyte_basics.basic_workflow.my_wf", "KVBL7dDsBdtaqjUgZIJzdQ==" - ) - - # Fetch artifact and execute workflow with it - a = r.get_artifact(uri="flyte://av0.1/flytesnacks/development/flyteorg.test.yt.test1:v0.1.6") - print(a) - r.execute(wf, inputs={"a": a}) - - # Just specify the artifact components - a = Artifact(project="flytesnacks", domain="development", suffix="7438595e5c0e63613dc8df41dac5ee40") - - -@pytest.mark.sandbox_test -def test_artifact_query(): - str_artifact = Artifact(name="flyteorg.test.yt.teststr", aliases=["latest"]) - - @task - def base_t1() -> Annotated[str, str_artifact]: - return "hello world" - - @workflow - def base_wf(): - base_t1() - - @task - def printer(a: str): - print(f"Task 2: {a}") - - @workflow - def user_wf(a: str = str_artifact.as_query()): - printer(a=a) - - -@pytest.mark.sandbox_test -def test_get_and_run(): - r = FlyteRemote( - Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), - default_project="flytesnacks", - default_domain="development", - ) - a = r.get_artifact(uri="flyte://av0.1/flytesnacks/development/a5zk94pb6lgg5v7l7zw8/n0/0/o:o0") - print(a) - - wf = r.fetch_workflow( - "flytesnacks", "development", "artifact_examples.consume_a_dataframe", "DZsIW4WlZPqKwJyRQ24SGw==" - ) - r.execute(wf, inputs={"df": a}) diff --git a/tests/flytekit/unit/core/test_artifacts_sandbox.py b/tests/flytekit/unit/core/test_artifacts_sandbox.py new file mode 100644 index 0000000000..a93dd34d33 --- /dev/null +++ b/tests/flytekit/unit/core/test_artifacts_sandbox.py @@ -0,0 +1,114 @@ +import pandas as pd +import pytest +from flyteidl.artifact import artifacts_pb2 +from flyteidl.core import identifier_pb2 +from typing_extensions import Annotated + +from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings +from flytekit.core.artifact import Artifact +from flytekit.core.task import task +from flytekit.core.workflow import workflow +from flytekit.remote.remote import FlyteRemote + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +# These test are not updated yet + + +@pytest.mark.sandbox_test +def test_create_an_artifact33_locally(): + import grpc + from flyteidl.artifact.artifacts_pb2_grpc import ArtifactRegistryStub + + local_artifact_channel = grpc.insecure_channel("127.0.0.1:50051") + stub = ArtifactRegistryStub(local_artifact_channel) + ak = identifier_pb2.ArtifactKey(project="flytesnacks", domain="development", name="f3bea14ee52f8409eb5b/n0/0/o/o0") + ai = identifier_pb2.ArtifactID(artifact_key=ak) + req = artifacts_pb2.GetArtifactRequest(query=identifier_pb2.ArtifactQuery(artifact_id=ai)) + x = stub.GetArtifact(req) + print(x) + + +@pytest.mark.sandbox_test +def test_create_an_artifact_locally(): + df = pd.DataFrame({"Name": ["Mary", "Jane"], "Age": [22, 23]}) + # a = Artifact.initialize(python_val=df, python_type=pd.DataFrame, name="flyteorg.test.yt.test1", + # aliases=["v0.1.0"]) + a = Artifact.initialize(python_val=42, python_type=int, name="flyteorg.test.yt.test1", aliases=["v0.1.6"]) + r = FlyteRemote( + Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), + default_project="flytesnacks", + default_domain="development", + ) + r.create_artifact(a) + print(a) + + +@pytest.mark.sandbox_test +def test_pull_artifact_and_use_to_launch(): + """ + df_artifact = Artifact("flyte://a1") + remote.execute(wf, inputs={"a": df_artifact}) + Artifact.i + """ + r = FlyteRemote( + Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), + default_project="flytesnacks", + default_domain="development", + ) + wf = r.fetch_workflow( + "flytesnacks", "development", "cookbook.core.flyte_basics.basic_workflow.my_wf", "KVBL7dDsBdtaqjUgZIJzdQ==" + ) + + # Fetch artifact and execute workflow with it + a = r.get_artifact(uri="flyte://av0.1/flytesnacks/development/flyteorg.test.yt.test1:v0.1.6") + print(a) + r.execute(wf, inputs={"a": a}) + + # Just specify the artifact components + a = Artifact(project="flytesnacks", domain="development", suffix="7438595e5c0e63613dc8df41dac5ee40") + + +@pytest.mark.sandbox_test +def test_artifact_query(): + str_artifact = Artifact(name="flyteorg.test.yt.teststr", aliases=["latest"]) + + @task + def base_t1() -> Annotated[str, str_artifact]: + return "hello world" + + @workflow + def base_wf(): + base_t1() + + @task + def printer(a: str): + print(f"Task 2: {a}") + + @workflow + def user_wf(a: str = str_artifact.as_query()): + printer(a=a) + + +@pytest.mark.sandbox_test +def test_get_and_run(): + r = FlyteRemote( + Config.auto(config_file="/Users/ytong/.flyte/local_admin.yaml"), + default_project="flytesnacks", + default_domain="development", + ) + a = r.get_artifact(uri="flyte://av0.1/flytesnacks/development/a5zk94pb6lgg5v7l7zw8/n0/0/o:o0") + print(a) + + wf = r.fetch_workflow( + "flytesnacks", "development", "artifact_examples.consume_a_dataframe", "DZsIW4WlZPqKwJyRQ24SGw==" + ) + r.execute(wf, inputs={"df": a}) diff --git a/tests/flytekit/unit/core/test_triggers.py b/tests/flytekit/unit/core/test_triggers.py index 2e568eae75..f25920d649 100644 --- a/tests/flytekit/unit/core/test_triggers.py +++ b/tests/flytekit/unit/core/test_triggers.py @@ -1,40 +1,69 @@ from datetime import timedelta -from flyteidl.core import identifier_pb2 as idl +from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import literals_pb2 from typing_extensions import Annotated -from flytekit.core.artifact import Artifact, Inputs +from flytekit.core.artifact import Artifact from flytekit.core.workflow import workflow from flytekit.trigger import Trigger +def test_basic_11(): + # This test would translate to + # Trigger(trigger_on=[hourlyArtifact], + # inputs={"x": hourlyArtifact}) + hourlyArtifact = Artifact( + name="hourly_artifact", + time_partitioned=True, + partition_keys=["region"], + ) + aq_idl = hourlyArtifact.embed_as_query([hourlyArtifact]) + assert aq_idl.HasField("binding") + assert aq_idl.binding.index == 0 + + def test_basic_1(): + # This test would translate to + # Trigger(trigger_on=[hourlyArtifact], + # inputs={"x": hourlyArtifact.query(region="LAX")}) + # note since hourlyArtifact is time partitioned, and it has one other partition key called some_dim, + # these should be bound to the trigger, and region should be a static value. hourlyArtifact = Artifact( - name="hourly_artifact", time_partition="{{ inputs.time_input }}", partitions={"region": "{{ inputs.region }} "} + name="hourly_artifact", + time_partitioned=True, + partition_keys=["region", "some_dim"], ) aq = hourlyArtifact.query(partitions={"region": "LAX"}) aq_idl = aq.to_flyte_idl([hourlyArtifact]) - assert aq_idl.artifact_id.partitions.value["ds"].binding.index == 0 + assert aq_idl.artifact_id.partitions.value["ds"].HasField("triggered_binding") + assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.index == 0 + assert aq_idl.artifact_id.partitions.value["some_dim"].HasField("triggered_binding") + assert aq_idl.artifact_id.partitions.value["some_dim"].triggered_binding.index == 0 + assert aq_idl.artifact_id.partitions.value["region"].static_value == "LAX" def test_basic_2(): - dailyArtifact = Artifact(name="daily_artifact", time_partition="{{ inputs.time_input }}") + dailyArtifact = Artifact(name="daily_artifact", time_partitioned=True) aq = dailyArtifact.query(time_partition=dailyArtifact.time_partition - timedelta(days=1)) aq_idl = aq.to_flyte_idl([dailyArtifact]) - assert aq_idl.artifact_id.partitions.value["ds"].binding.partition_key == "ds" - assert aq_idl.artifact_id.partitions.value["ds"].binding.transform is not None + x = aq_idl.artifact_id.partitions.value + assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.index == 0 + assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.partition_key == "ds" + assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.transform is not None def test_big_trigger(): - dailyArtifact = Artifact(name="daily_artifact", time_partition="{{ inputs.time_input }}") + dailyArtifact = Artifact(name="daily_artifact", time_partitioned=True) hourlyArtifact = Artifact( - name="hourly_artifact", time_partition="{{ inputs.time_input }}", partitions={"region": "{{ inputs.region }} "} + name="hourly_artifact", + time_partitioned=True, + partition_keys=["region"], ) - UnrelatedArtifact = Artifact(name="unrelated_artifact", time_partition="{{ inputs.date }}") - UnrelatedTwo = Artifact(name="unrelated_two", partitions={"region": "{{ inputs.region }} "}) + UnrelatedArtifact = Artifact(name="unrelated_artifact", time_partitioned=True) + UnrelatedTwo = Artifact(name="unrelated_two", partition_keys=["region"]) t = Trigger( # store these locally. @@ -49,7 +78,7 @@ def test_big_trigger(): "region": "SEA", # static value that will be passed as input "other_artifact": UnrelatedArtifact.query(time_partition=dailyArtifact.time_partition), "other_artifact_2": UnrelatedArtifact.query(time_partition=hourlyArtifact.time_partition.truncate_to_day()), - "other_artifact_3": UnrelatedTwo.query(partitions={"region": hourlyArtifact.partitions.region}), + "other_artifact_3": UnrelatedTwo.query(partitions={"rgg": hourlyArtifact.partitions.region}), }, ) @@ -67,34 +96,34 @@ def my_workflow( pm = t.get_parameter_map(my_workflow.python_interface.inputs, my_workflow.interface.inputs) - assert pm.parameters["today_upstream"].artifact_query == idl.ArtifactQuery( - binding=idl.ArtifactBindingData( + assert pm.parameters["today_upstream"].artifact_query == art_id.ArtifactQuery( + binding=art_id.ArtifactBindingData( index=0, ), ) assert not pm.parameters["today_upstream"].artifact_query.binding.partition_key assert not pm.parameters["today_upstream"].artifact_query.binding.transform - assert pm.parameters["yesterday_upstream"].artifact_query == idl.ArtifactQuery( - artifact_id=idl.ArtifactID( - artifact_key=idl.ArtifactKey(project=None, domain=None, name="daily_artifact"), - partitions=idl.Partitions( + assert pm.parameters["yesterday_upstream"].artifact_query == art_id.ArtifactQuery( + artifact_id=art_id.ArtifactID( + artifact_key=art_id.ArtifactKey(project=None, domain=None, name="daily_artifact"), + partitions=art_id.Partitions( value={ - "ds": idl.PartitionValue( - binding=idl.ArtifactBindingData(index=0, partition_key="ds", transform="-P1D") + "ds": art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=0, partition_key="ds", transform="-P1D") ), } ), ), ) - assert pm.parameters["other_daily_upstream"].artifact_query == idl.ArtifactQuery( - artifact_id=idl.ArtifactID( - artifact_key=idl.ArtifactKey(project=None, domain=None, name="hourly_artifact"), - partitions=idl.Partitions( + assert pm.parameters["other_daily_upstream"].artifact_query == art_id.ArtifactQuery( + artifact_id=art_id.ArtifactID( + artifact_key=art_id.ArtifactKey(project=None, domain=None, name="hourly_artifact"), + partitions=art_id.Partitions( value={ - "ds": idl.PartitionValue(binding=idl.ArtifactBindingData(index=1, partition_key="ds")), - "region": idl.PartitionValue(static_value="LAX"), + "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=1, partition_key="ds")), + "region": art_id.LabelValue(static_value="LAX"), } ), ), @@ -104,39 +133,45 @@ def my_workflow( scalar=literals_pb2.Scalar(primitive=literals_pb2.Primitive(string_value="SEA")) ) - assert pm.parameters["other_artifact"].artifact_query == idl.ArtifactQuery( - artifact_id=idl.ArtifactID( - artifact_key=idl.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), - partitions=idl.Partitions( + assert pm.parameters["other_artifact"].artifact_query == art_id.ArtifactQuery( + artifact_id=art_id.ArtifactID( + artifact_key=art_id.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), + partitions=art_id.Partitions( value={ - "ds": idl.PartitionValue(binding=idl.ArtifactBindingData(index=0, partition_key="ds")), + "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=0, partition_key="ds")), } ), ) ) - assert pm.parameters["other_artifact_2"].artifact_query == idl.ArtifactQuery( - artifact_id=idl.ArtifactID( - artifact_key=idl.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), - partitions=idl.Partitions( + assert pm.parameters["other_artifact_2"].artifact_query == art_id.ArtifactQuery( + artifact_id=art_id.ArtifactID( + artifact_key=art_id.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), + partitions=art_id.Partitions( value={ - "ds": idl.PartitionValue(binding=idl.ArtifactBindingData(index=1, partition_key="ds")), + "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=1, partition_key="ds")), } ), ) ) - assert pm.parameters["other_artifact_3"].artifact_query == idl.ArtifactQuery( - artifact_id=idl.ArtifactID( - artifact_key=idl.ArtifactKey(project=None, domain=None, name="unrelated_two"), - partitions=idl.Partitions( + assert pm.parameters["other_artifact_3"].artifact_query == art_id.ArtifactQuery( + artifact_id=art_id.ArtifactID( + artifact_key=art_id.ArtifactKey(project=None, domain=None, name="unrelated_two"), + partitions=art_id.Partitions( value={ - "region": idl.PartitionValue(binding=idl.ArtifactBindingData(index=1, partition_key="region")), + "rgg": art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=1, partition_key="region") + ), } ), ) ) + idl_t = t.to_flyte_idl() + assert idl_t.triggers[0].partitions.value["ds"] is not None + assert idl_t.triggers[1].partitions.value["ds"] is not None + # Test calling it to create the LaunchPlan object which adds to the global context @t @workflow @@ -153,7 +188,7 @@ def tst_wf( def test_partition_only(): - dailyArtifact = Artifact(name="daily_artifact", time_partition="{{ inputs.time_input }}") + dailyArtifact = Artifact(name="daily_artifact", time_partitioned=True) t = Trigger( # store these locally. @@ -170,8 +205,8 @@ def tst_wf( ... pm = t.get_parameter_map(tst_wf.python_interface.inputs, tst_wf.interface.inputs) - assert pm.parameters["today_upstream"].artifact_query == idl.ArtifactQuery( - binding=idl.ArtifactBindingData( + assert pm.parameters["today_upstream"].artifact_query == art_id.ArtifactQuery( + binding=art_id.ArtifactBindingData( index=0, partition_key="ds", transform="-P1D",