Skip to content

Commit

Permalink
use separate time partition in idl change flyteorg/flyte#4737
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Jan 18, 2024
1 parent 685403d commit 692b78c
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 135 deletions.
150 changes: 84 additions & 66 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from flyteidl.artifact import artifacts_pb2
from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core.identifier_pb2 import TaskExecutionIdentifier, WorkflowExecutionIdentifier
from flyteidl.core.literals_pb2 import Literal
from flyteidl.core.types_pb2 import LiteralType
from google.protobuf.timestamp_pb2 import Timestamp

from flytekit.loggers import logger
from flytekit.models.literals import Literal
Expand Down Expand Up @@ -51,12 +50,10 @@ class ArtifactIDSpecification(object):
having a pointer to the main artifact.
"""

def __init__(
self, a: Artifact, partitions: Optional[Partitions] = None, time_partition: Optional[TimePartition] = None
):
def __init__(self, a: Artifact):
self.artifact = a
self.partitions = partitions
self.time_partition = time_partition
self.partitions: Optional[Partitions] = None
self.time_partition: Optional[TimePartition] = None

# todo: add time partition arg hint
def __call__(self, *args, **kwargs):
Expand All @@ -72,7 +69,9 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification:
raise ValueError("Cannot bind time partition to non-time partitioned artifact")
p = kwargs[TIME_PARTITION_KWARG]
if isinstance(p, datetime.datetime):
self.time_partition = TimePartition(value=art_id.LabelValue(static_value=f"{p}"))
t = Timestamp()
t.FromDatetime(p)
self.time_partition = TimePartition(value=art_id.LabelValue(time_value=t))
elif isinstance(p, art_id.InputBindingData):
self.time_partition = TimePartition(value=art_id.LabelValue(input_binding=p))
else:
Expand Down Expand Up @@ -102,20 +101,30 @@ def to_partial_artifact_id(self) -> art_id.ArtifactID:
# This function should only be called by transform_variable_map
artifact_id = self.artifact.to_flyte_idl().artifact_id
# Use the partitions from this object, but replacement is not allowed by protobuf, so generate new object
p = partitions_to_idl(self.partitions, self.time_partition)
p = partitions_to_idl(self.partitions)
tp = None
if self.artifact.time_partitioned:
if not self.time_partition:
raise ValueError(
f"Artifact {artifact_id.artifact_key} requires a time partition, but it hasn't been bound."
)
tp = self.time_partition.to_flyte_idl()

if self.artifact.partition_keys:
required = len(self.artifact.partition_keys)
required += 1 if self.artifact.time_partitioned else 0
# required += 1 if self.artifact.time_partitioned else 0
fulfilled = len(p.value) if p else 0
if required != fulfilled:
raise ValueError(
f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are bound."
f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are "
f"bound."
)
artifact_id = art_id.ArtifactID(
artifact_key=artifact_id.artifact_key,
partitions=p,
version=artifact_id.version,
time_partition=tp,
version=artifact_id.version, # this should almost never be set since setting it
# hardcodes the query to one version
)
return artifact_id

Expand Down Expand Up @@ -155,9 +164,7 @@ def __init__(
tag: Optional[str] = None,
):
if not name:
raise ValueError(f"Cannot create query without name")
if partitions and partitions.partitions and TIME_PARTITION in partitions.partitions:
raise ValueError(f"Cannot use 'ds' as a partition name, just use time partition")
raise ValueError("Cannot create query without name")

# So normally, if you just do MyData.query(partitions={"region": "{{ inputs.region }}"}), it will just
# use the input value to fill in the partition. But if you do
Expand Down Expand Up @@ -203,11 +210,15 @@ def to_flyte_idl(
)
return aq

p = partitions_to_idl(self.partitions, self.time_partition, bindings)
p = partitions_to_idl(self.partitions, bindings)
tp = None
if self.time_partition:
tp = self.time_partition.to_flyte_idl(bindings)

i = art_id.ArtifactID(
artifact_key=ak,
partitions=p,
time_partition=tp,
)

aq = art_id.ArtifactQuery(
Expand All @@ -232,13 +243,15 @@ def __init__(
other: Optional[timedelta] = None,
):
if isinstance(value, str):
value = art_id.LabelValue(static_value=value)
raise ValueError(f"value to a time partition shouldn't be a str {value}")
elif isinstance(value, datetime.datetime):
value = art_id.LabelValue(static_value=f"{value}")
t = Timestamp()
t.FromDatetime(value)
value = art_id.LabelValue(time_value=t)
elif isinstance(value, art_id.InputBindingData):
value = art_id.LabelValue(input_binding=value)
# else should already be a LabelValue or None
self.value = value
self.value: art_id.LabelValue = value
self.op = op
self.other = other
self.reference_artifact: Optional[Artifact] = None
Expand All @@ -253,15 +266,11 @@ def __sub__(self, other: timedelta) -> TimePartition:
tp.reference_artifact = self.reference_artifact
return tp

def truncate_to_day(self):
# raise NotImplementedError("Not implemented yet")
return self

def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art_id.Partitions:
def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art_id.TimePartition:
if not self.reference_artifact or (self.reference_artifact and self.reference_artifact not in bindings):
# basically if there's no reference artifact, or if the reference artifact isn't
# in the list of triggers, then treat it like normal.
return art_id.Partitions(value={TIME_PARTITION: self.value})
return art_id.TimePartition(value=self.value)
elif self.reference_artifact in bindings:
idx = bindings.index(self.reference_artifact)
transform = None
Expand All @@ -270,55 +279,35 @@ def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art
lv = art_id.LabelValue(
triggered_binding=art_id.ArtifactBindingData(
index=idx,
partition_key=TIME_PARTITION,
bind_to_time_partition=True,
transform=transform,
)
)
return art_id.Partitions(value={TIME_PARTITION: lv})
return art_id.TimePartition(value=lv)
# investigate if this happens, if not, remove. else
logger.warning(f"Investigate - time partition in trigger with unhandled reference artifact {self}")
return art_id.Partitions(value={TIME_PARTITION: self.value})
raise ValueError("Time partition reference artifact not found in ")
# return art_id.Partitions(value={TIME_PARTITION: self.value})

def to_flyte_idl(self, bindings: Optional[typing.List[Artifact]] = None) -> Optional[art_id.Partitions]:
def to_flyte_idl(self, bindings: Optional[typing.List[Artifact]] = None) -> Optional[art_id.TimePartition]:
if bindings and len(bindings) > 0:
return self.get_idl_partitions_for_trigger(bindings)

if not self.value:
# This is only for triggers - the backend needs to know of the existence of a time partition
return art_id.Partitions(value={TIME_PARTITION: art_id.LabelValue(static_value="")})

return art_id.Partitions(value={TIME_PARTITION: self.value})

return art_id.TimePartition()

def merge_idl_partitions(
p_idl: Optional[art_id.Partitions], time_p_idl: Optional[art_id.Partitions]
) -> Optional[art_id.Partitions]:
if not p_idl and not time_p_idl:
return None
p = {}
if p_idl and p_idl.value:
p.update(p_idl.value)
if time_p_idl and time_p_idl.value:
p.update(time_p_idl.value)

return art_id.Partitions(value=p) if p else None
return art_id.TimePartition(value=self.value)


def partitions_to_idl(
partitions: Optional[Partitions],
time_partition: Optional[TimePartition],
bindings: Optional[typing.List[Artifact]] = None,
) -> Optional[art_id.Partitions]:
partition_idl = None
if partitions:
partition_idl = partitions.to_flyte_idl(bindings)

time_p_idl = None
if time_partition:
time_p_idl = time_partition.to_flyte_idl(bindings)
return partitions.to_flyte_idl(bindings)

merged = merge_idl_partitions(partition_idl, time_p_idl)
return merged
return None


class Partition(object):
Expand Down Expand Up @@ -378,7 +367,7 @@ def get_idl_partitions_for_trigger(
if not v.reference_artifact or (
v.reference_artifact
and v.reference_artifact is self.reference_artifact
and not v.reference_artifact in bindings
and v.reference_artifact not in bindings
):
# consider changing condition to just check for static value
p[k] = art_id.LabelValue(static_value=v.value.static_value)
Expand Down Expand Up @@ -543,12 +532,16 @@ def time_partition(self) -> TimePartition:
return self._time_partition

def __str__(self):
tp_str = f" time partition={self.time_partition}\n" if self.time_partitioned else ""
return (
f"Artifact: project={self.project}, domain={self.domain}, name={self.name}, version={self.version}\n"
f" name={self.name}\n"
f" partitions={self.partitions}\n"
f"{tp_str}"
f" tags={self.tags}\n"
f" literal_type={self.literal_type}, literal={self.literal})"
f" literal_type="
f"{self.literal_type}, "
f"literal={self.literal})"
)

def __repr__(self):
Expand Down Expand Up @@ -668,21 +661,30 @@ def as_artifact_id(self) -> art_id.ArtifactID:
return self.to_flyte_idl().artifact_id

def embed_as_query(
self, bindings: typing.List[Artifact], partition: Optional[str] = None, expr: Optional[str] = None
self,
bindings: typing.List[Artifact],
partition: Optional[str] = None,
bind_to_time_partition: Optional[bool] = None,
expr: Optional[str] = None,
) -> art_id.ArtifactQuery:
"""
This should only be called in the context of a Trigger
:param bindings: The list of artifacts in trigger_on
:param partition: Can embed a time partition
:param bind_to_time_partition: Set to true if you want to bind to a time partition
:param expr: Only valid if there's a time partition.
"""
# Find self in the list, raises ValueError if not there.
idx = bindings.index(self)
aq = art_id.ArtifactQuery(
binding=art_id.ArtifactBindingData(
index=idx, partition_key=partition, transform=str(expr) if expr and partition else None
index=idx,
partition_key=partition,
bind_to_time_partition=bind_to_time_partition,
transform=str(expr) if expr and (partition or bind_to_time_partition) else None,
)
)

return aq

def to_flyte_idl(self) -> artifacts_pb2.Artifact:
Expand All @@ -691,7 +693,8 @@ def to_flyte_idl(self) -> artifacts_pb2.Artifact:
This is here instead of translator because it's in the interface, a relatively simple proto object
that's exposed to the user.
"""
p = partitions_to_idl(self.partitions, self.time_partition if self.time_partitioned else None)
p = partitions_to_idl(self.partitions)
tp = self.time_partition.to_flyte_idl() if self.time_partitioned else None

return artifacts_pb2.Artifact(
artifact_id=art_id.ArtifactID(
Expand All @@ -702,6 +705,7 @@ def to_flyte_idl(self) -> artifacts_pb2.Artifact:
),
version=self.version,
partitions=p,
time_partition=tp,
),
spec=artifacts_pb2.ArtifactSpec(),
tags=self.tags,
Expand All @@ -717,9 +721,18 @@ def as_create_request(self) -> artifacts_pb2.CreateArtifactRequest:
value=self.literal,
type=self.literal_type,
)
partitions = partitions_to_idl(self.partitions, self.time_partition)
tag = self.tags[0] if self.tags else None
return artifacts_pb2.CreateArtifactRequest(artifact_key=ak, spec=spec, partitions=partitions, tag=tag)
partitions = partitions_to_idl(self.partitions)

tp = None
if self._time_partition:
tv = self.time_partition.value.time_value
if not tv:
raise Exception("missing time value")
tp = self.time_partition.value.time_value

return artifacts_pb2.CreateArtifactRequest(
artifact_key=ak, spec=spec, partitions=partitions, time_partition_value=tp
)

@classmethod
def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact:
Expand All @@ -741,17 +754,22 @@ def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact:
if len(pb2.artifact_id.partitions.value) > 0:
# static values should be the only ones set since currently we don't from_flyte_idl
# anything that's not a materialized artifact.
if TIME_PARTITION in pb2.artifact_id.partitions.value:
a._time_partition = TimePartition(pb2.artifact_id.partitions.value[TIME_PARTITION].static_value)
a._time_partition.reference_artifact = a
# if TIME_PARTITION in pb2.artifact_id.partitions.value:
# a._time_partition = TimePartition(pb2.artifact_id.partitions.value[TIME_PARTITION].static_value)
# a._time_partition.reference_artifact = a

a._partitions = Partitions(
partitions={
k: Partition(value=v, name=k)
for k, v in pb2.artifact_id.partitions.value.items()
if k != TIME_PARTITION
# if k != TIME_PARTITION
}
)
a.partitions.reference_artifact = a
if pb2.artifact_id.HasField("time_partition"):
ts = pb2.artifact_id.time_partition.value.time_value
dt = ts.ToDatetime()
a._time_partition = TimePartition(dt)
a._time_partition.reference_artifact = a

return a
4 changes: 2 additions & 2 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime as _datetime
from datetime import timezone as _timezone
from typing import Optional, Dict
from typing import Dict, Optional

from flyteidl.core import literals_pb2 as _literals_pb2
from google.protobuf.struct_pb2 import Struct
Expand Down Expand Up @@ -859,7 +859,7 @@ def __init__(
collection: Optional[LiteralCollection] = None,
map: Optional[LiteralMap] = None,
hash: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, str]] = None,
):
"""
This IDL message represents a literal value in the Flyte ecosystem.
Expand Down
4 changes: 2 additions & 2 deletions flytekit/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flyteidl.core import identifier_pb2 as idl
from flyteidl.core import interface_pb2

from flytekit.core.artifact import TIME_PARTITION, Artifact, ArtifactQuery, Partition, TimePartition
from flytekit.core.artifact import Artifact, ArtifactQuery, Partition, TimePartition
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.tracker import TrackedInstance
Expand Down Expand Up @@ -77,7 +77,7 @@ def get_parameter_map(
expr = None
if v.op and v.other and isinstance(v.other, timedelta):
expr = str(v.op) + isodate.duration_isoformat(v.other)
aq = v.reference_artifact.embed_as_query(self.triggers, TIME_PARTITION, expr)
aq = v.reference_artifact.embed_as_query(self.triggers, bind_to_time_partition=True, expr=expr)
p = interface_pb2.Parameter(var=var, artifact_query=aq)
elif isinstance(v, Partition):
# The reason is that if we bind to arbitrary partitions, we'll have to start keeping track of types
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from setuptools import setup

setup()
setup()
Loading

0 comments on commit 692b78c

Please sign in to comment.