diff --git a/.github/workflows/build_image.yml b/.github/workflows/build_image.yml
index c2d0e97bcb..e3b894a4e2 100644
--- a/.github/workflows/build_image.yml
+++ b/.github/workflows/build_image.yml
@@ -8,7 +8,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11"]
     steps:
       - uses: actions/checkout@v4
         with:
diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml
index 8ec1c1c0a7..d1de1bee2a 100644
--- a/.github/workflows/pythonbuild.yml
+++ b/.github/workflows/pythonbuild.yml
@@ -15,6 +15,41 @@ concurrency:
 jobs:
   build:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+        python-version: ["3.8", "3.11", "3.12"]
+    steps:
+      - uses: insightsengineering/disk-space-reclaimer@v1
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache pip
+        uses: actions/cache@v3
+        with:
+          # This path is specific to Ubuntu
+          path: ~/.cache/pip
+          # Look to see if there is a cache hit for the corresponding requirements files
+          key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
+      - name: Install dependencies
+        run: |
+          make setup
+          pip uninstall -y pandas
+          pip freeze
+      - name: Test with coverage
+        run: |
+          make unit_test_codecov
+      - name: Codecov
+        uses: codecov/codecov-action@v3.1.4
+        with:
+          fail_ci_if_error: false
+          files: coverage.xml
+
+  build-with-extras:
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -46,9 +81,6 @@ jobs:
         if: ${{ matrix.python-version != '3.12' }}
         run: |
           make unit_test_extras_codecov
-      - name: Test with coverage
-        run: |
-          make unit_test_codecov
       - name: Codecov
        uses: codecov/codecov-action@v3.1.4
         with:
@@ -207,6 +239,7 @@ jobs:
           # onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
           # The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
           # flytekit-onnx-tensorflow
+          - flytekit-openai
           - flytekit-pandera
           - flytekit-papermill
           - flytekit-polars
@@ -250,6 +283,7 @@ jobs:
             plugin-names: "flytekit-whylogs"
     steps:
       - uses: insightsengineering/disk-space-reclaimer@v1
+        if: ${{ matrix.plugin-names == 'flytekit-envd' }}
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
@@ -264,11 +298,12 @@
           key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('plugins/{0}/requirements.txt', matrix.plugin-names ))) }}
       - name: Install dependencies
         run: |
+          export SETUPTOOLS_SCM_PRETEND_VERSION="2.0.0"
           make setup
           cd plugins/${{ matrix.plugin-names }}
-          pip install --pre .
-          if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi
-          pip install --pre -U $GITHUB_WORKSPACE
+          pip install .
+          if [ -f dev-requirements.in ]; then pip install -r dev-requirements.in; fi
+          pip install -U $GITHUB_WORKSPACE
           pip freeze
       - name: Test with coverage
         run: |
diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml
index d82fcfb4d9..35b098524a 100644
--- a/.github/workflows/pythonpublish.yml
+++ b/.github/workflows/pythonpublish.yml
@@ -24,7 +24,7 @@ jobs:
         run: |
           # from refs/tags/v1.2.3 get 1.2.3
           VERSION=$(echo $GITHUB_REF | sed 's#.*/v##')
-          echo "::set-output name=version::$VERSION"
+          echo "version=$VERSION" >> $GITHUB_OUTPUT
         shell: bash
       - name: Build and publish
         env:
@@ -167,6 +167,16 @@ jobs:
           registry: ghcr.io
           username: "${{ secrets.FLYTE_BOT_USERNAME }}"
           password: "${{ secrets.FLYTE_BOT_PAT }}"
+      - name: Prepare Flyte Agent Slim Image Names
+        id: flyteagent-slim-names
+        uses: docker/metadata-action@v3
+        with:
+          images: |
+            ghcr.io/${{ github.repository_owner }}/flyteagent-slim
+          tags: |
+            latest
+            ${{ github.sha }}
+            ${{ needs.deploy.outputs.version }}
       - name: Prepare Flyte Agent Image Names
         id: flyteagent-names
         uses: docker/metadata-action@v3
         with:
           images: |
             ghcr.io/${{ github.repository_owner }}/flyteagent
@@ -177,11 +187,25 @@
             latest
             ${{ github.sha }}
             ${{ needs.deploy.outputs.version }}
-      - name: Push External Plugin Service Image to GitHub Registry
+      - name: Push flyteagent-slim Image to GitHub Registry
+        uses: docker/build-push-action@v2
+        with:
+          context: "."
+          platforms: linux/arm64, linux/amd64
+          target: agent-slim
+          push: ${{ github.event_name == 'release' }}
+          tags: ${{ steps.flyteagent-slim-names.outputs.tags }}
+          build-args: |
+            VERSION=${{ needs.deploy.outputs.version }}
+          file: ./Dockerfile.agent
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+      - name: Push flyteagent-all Image to GitHub Registry
         uses: docker/build-push-action@v2
         with:
           context: "."
           platforms: linux/arm64, linux/amd64
+          target: agent-all
           push: ${{ github.event_name == 'release' }}
           tags: ${{ steps.flyteagent-names.outputs.tags }}
           build-args: |
diff --git a/Dockerfile.agent b/Dockerfile.agent
index fe4ce56290..fbd8c4c4a8 100644
--- a/Dockerfile.agent
+++ b/Dockerfile.agent
@@ -1,4 +1,4 @@
-FROM python:3.9-slim-bookworm
+FROM python:3.9-slim-bookworm as agent-slim
 MAINTAINER Flyte Team
 LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit
@@ -7,12 +7,11 @@ ARG VERSION
 
 RUN apt-get update && apt-get install build-essential -y
 
-RUN pip install prometheus-client
+RUN pip install prometheus-client grpcio-health-checking
 
 RUN pip install --no-cache-dir -U flytekit==$VERSION \
-	flytekitplugins-bigquery==$VERSION \
 	flytekitplugins-airflow==$VERSION \
-	flytekitplugins-mmcloud==$VERSION \
-	flytekitplugins-spark==$VERSION \
+	flytekitplugins-bigquery==$VERSION \
+	flytekitplugins-chatgpt==$VERSION \
 	flytekitplugins-snowflake==$VERSION \
 	&& apt-get clean autoclean \
 	&& apt-get autoremove --yes \
@@ -20,3 +19,10 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \
 	&& :
 
 CMD pyflyte serve agent --port 8000
+
+FROM agent-slim as agent-all
+ARG VERSION
+
+RUN pip install --no-cache-dir -U \
+	flytekitplugins-mmcloud==$VERSION \
+	flytekitplugins-spark==$VERSION
diff --git a/Dockerfile.dev b/Dockerfile.dev
index 2b85a5f7d2..63019e5e38 100644
--- a/Dockerfile.dev
+++ b/Dockerfile.dev
@@ -28,8 +28,8 @@ COPY . /flytekit
 # 3. Clean up the apt cache to reduce image size. Reference: https://gist.github.com/marvell/7c812736565928e602c4
 # 4. Create a non-root user 'flytekit' and set appropriate permissions for directories.
 RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \
+    && pip install "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \
     && pip install --no-cache-dir -U --pre \
-        flyteidl \
         -e /flytekit \
         -e /flytekit/plugins/flytekit-k8s-pod \
         -e /flytekit/plugins/flytekit-deck-standard \
@@ -43,6 +43,7 @@ RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \
     && chown flytekit: /home \
     && :
 
+
 ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:"
 
 # Switch to the 'flytekit' user for better security.
diff --git a/Makefile b/Makefile
index 9ec2f6ba80..fa245479dd 100644
--- a/Makefile
+++ b/Makefile
@@ -25,8 +25,7 @@ update_boilerplate:
 .PHONY: setup
 setup: install-piptools ## Install requirements
-	pip install --pre -r dev-requirements.in
-
+	pip install -r dev-requirements.in
 
 .PHONY: fmt
 fmt:
@@ -63,14 +62,14 @@ unit_test_extras_codecov:
 unit_test:
 	# Skip all extra tests and run them with the necessary env var set so that a working (albeit slower)
 	# library is used to serialize/deserialize protobufs is used.
-	$(PYTEST_AND_OPTS) -m "not (serial or sandbox_test)" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models ${CODECOV_OPTS}
+	$(PYTEST_AND_OPTS) -m "not (serial or sandbox_test)" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models --ignore=tests/flytekit/unit/extend ${CODECOV_OPTS}
 	# Run serial tests without any parallelism
-	$(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models ${CODECOV_OPTS}
+	$(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models --ignore=tests/flytekit/unit/extend ${CODECOV_OPTS}
 
 .PHONY: unit_test_extras
 unit_test_extras:
-	PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras ${CODECOV_OPTS}
+	PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras tests/flytekit/unit/extend ${CODECOV_OPTS}
 
 .PHONY: test_serialization_codecov
 test_serialization_codecov:
diff --git a/dev-requirements.in b/dev-requirements.in
index d9784f75d0..d866cfc1c8 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -1,4 +1,5 @@
 -e file:.#egg=flytekit
+git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl
 
 coverage[toml]
 hypothesis
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index bd87f7b64b..e2887ce51a 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -595,7 +595,7 @@ def _get_params(
         defaults: typing.Optional[typing.Dict[str, Parameter]] = None,
     ) -> typing.List["click.Parameter"]:
         params = []
-        flyte_ctx = context_manager.FlyteContextManager.current_context()
+        flyte_ctx = ctx.obj.remote_instance().context
         for name, var in inputs.items():
             if fixed and name in fixed:
                 continue
@@ -609,6 +609,7 @@ def _get_params(
         return params
 
     def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]:
+        ctx.obj.remote = True
         if not self.params:
             self.params = []
             entity = self._fetch_entity(ctx)
diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py
index 87f008b084..6a7e5c3c28 100644
--- a/flytekit/clis/sdk_in_container/serve.py
+++ b/flytekit/clis/sdk_in_container/serve.py
@@ -1,11 +1,13 @@
 from concurrent import futures
 
+import grpc
 import rich_click as click
+from flyteidl.service import agent_pb2
 from flyteidl.service.agent_pb2_grpc import (
     add_AgentMetadataServiceServicer_to_server,
     add_AsyncAgentServiceServicer_to_server,
+    add_SyncAgentServiceServicer_to_server,
 )
-from grpc import aio
 
 
 @click.group("serve")
@@ -51,21 +53,46 @@ def agent(_: click.Context, port, worker, timeout):
 
 
 async def _start_grpc_server(port: int, worker: int, timeout: int):
-    click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
-    from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
+    from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService
 
-    try:
-        from prometheus_client import start_http_server
-
-        start_http_server(9090)
-    except ImportError as e:
-        click.secho(f"Failed to start the prometheus server with error {e}", fg="red")
+    _start_http_server()
     click.secho("Starting the agent service...", fg="blue")
-    server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))
+    server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))
 
     add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
+    add_SyncAgentServiceServicer_to_server(SyncAgentService(), server)
     add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)
+    _start_health_check_server(server, worker)
 
     server.add_insecure_port(f"[::]:{port}")
     await server.start()
     await server.wait_for_termination(timeout)
+
+
+def _start_http_server():
+    try:
+        from prometheus_client import start_http_server
+
+        click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
+        start_http_server(9090)
+    except ImportError as e:
+        click.secho(f"Failed to start the prometheus server with error {e}", fg="red")
+
+
+def _start_health_check_server(server: grpc.Server, worker: int):
+    try:
+        from grpc_health.v1 import health, health_pb2, health_pb2_grpc
+
+        health_servicer = health.HealthServicer(
+            experimental_non_blocking=True,
+            experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=worker),
+        )
+
+        for service in agent_pb2.DESCRIPTOR.services_by_name.values():
+            health_servicer.set(service.full_name, health_pb2.HealthCheckResponse.SERVING)
+        health_servicer.set(health.SERVICE_NAME, health_pb2.HealthCheckResponse.SERVING)
+
+        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
+
+    except ImportError as e:
+        click.secho(f"Failed to start the health check servicer with error {e}", fg="red")
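Note: the health servicer registered above can be exercised with the standard gRPC health-checking client. A minimal sketch, assuming the agent is listening on localhost:8000 and grpcio-health-checking is installed (the address and empty service name are illustrative, not part of this PR):

# Hypothetical probe against the agent server started by `pyflyte serve agent`.
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

channel = grpc.insecure_channel("localhost:8000")  # assumed agent address
stub = health_pb2_grpc.HealthStub(channel)
# An empty service name queries the overall server status.
response = stub.Check(health_pb2.HealthCheckRequest(service=""))
print(health_pb2.HealthCheckResponse.ServingStatus.Name(response.status))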
diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py
index 6c709f59a1..27d16b4822 100644
--- a/flytekit/core/artifact.py
+++ b/flytekit/core/artifact.py
@@ -147,7 +147,13 @@ def __init__(
         self.time_partition = time_partition
         self.partitions = partitions
         self.tag = tag
-        self.bindings = bindings
+        if len(bindings) > 0:
+            b = set(bindings)
+            if len(b) > 1:
+                raise ValueError(f"Multiple bindings found in query {self}")
+            self.binding: Optional[Artifact] = bindings[0]
+        else:
+            self.binding = None
 
     def to_flyte_idl(
         self,
@@ -391,23 +397,19 @@ def concrete_artifact_id(self) -> art_id.ArtifactID:
 
     def embed_as_query(
         self,
-        bindings: typing.List[Artifact],
         partition: Optional[str] = None,
         bind_to_time_partition: Optional[bool] = None,
         expr: Optional[str] = None,
     ) -> art_id.ArtifactQuery:
         """
         This should only be called in the context of a Trigger
-        :param bindings: The list of artifacts in trigger_on
         :param partition: Can embed a time partition
         :param bind_to_time_partition: Set to true if you want to bind to a time partition
         :param expr: Only valid if there's a time partition.
         """
-        # Find self in the list, raises ValueError if not there.
-        idx = bindings.index(self)
         aq = art_id.ArtifactQuery(
             binding=art_id.ArtifactBindingData(
-                index=idx,
                 partition_key=partition,
                 bind_to_time_partition=bind_to_time_partition,
                 transform=str(expr) if expr and (partition or bind_to_time_partition) else None,
diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py
index 7f45287428..0b097ad847 100644
--- a/flytekit/core/launch_plan.py
+++ b/flytekit/core/launch_plan.py
@@ -8,6 +8,7 @@
 from flytekit.core.interface import Interface, transform_function_to_interface, transform_inputs_to_parameters
 from flytekit.core.promise import create_and_link_node, translate_inputs_to_literals
 from flytekit.core.reference_entity import LaunchPlanReference, ReferenceEntity
+from flytekit.core.schedule import LaunchPlanTriggerBase
 from flytekit.models import common as _common_models
 from flytekit.models import interface as _interface_models
 from flytekit.models import literals as _literal_models
@@ -123,6 +124,7 @@ def create(
         max_parallelism: Optional[int] = None,
         security_context: Optional[security.SecurityContext] = None,
         auth_role: Optional[_common_models.AuthRole] = None,
+        trigger: Optional[LaunchPlanTriggerBase] = None,
         overwrite_cache: Optional[bool] = None,
     ) -> LaunchPlan:
         ctx = FlyteContextManager.current_context()
@@ -174,6 +176,7 @@ def create(
             raw_output_data_config=raw_output_data_config,
             max_parallelism=max_parallelism,
             security_context=security_context,
+            trigger=trigger,
             overwrite_cache=overwrite_cache,
         )
 
@@ -203,6 +206,7 @@ def get_or_create(
         max_parallelism: Optional[int] = None,
         security_context: Optional[security.SecurityContext] = None,
         auth_role: Optional[_common_models.AuthRole] = None,
+        trigger: Optional[LaunchPlanTriggerBase] = None,
         overwrite_cache: Optional[bool] = None,
     ) -> LaunchPlan:
         """
@@ -229,6 +233,7 @@ def get_or_create(
         :param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
             workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
             parallelism/concurrency of MapTasks is independent from this.
+        :param trigger: [alpha] This is a new syntax for specifying schedules.
""" if name is None and ( default_inputs is not None @@ -241,6 +246,7 @@ def get_or_create( or auth_role is not None or max_parallelism is not None or security_context is not None + or trigger is not None or overwrite_cache is not None ): raise ValueError( @@ -299,6 +305,7 @@ def get_or_create( max_parallelism, auth_role=auth_role, security_context=security_context, + trigger=trigger, overwrite_cache=overwrite_cache, ) LaunchPlan.CACHE[name or workflow.name] = lp @@ -317,8 +324,8 @@ def __init__( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, - additional_metadata: Optional[Any] = None, ): self._name = name self._workflow = workflow @@ -336,8 +343,8 @@ def __init__( self._raw_output_data_config = raw_output_data_config self._max_parallelism = max_parallelism self._security_context = security_context + self._trigger = trigger self._overwrite_cache = overwrite_cache - self._additional_metadata = additional_metadata FlyteEntities.entities.append(self) @@ -353,6 +360,7 @@ def clone_with( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: return LaunchPlan( @@ -367,6 +375,7 @@ def clone_with( raw_output_data_config=raw_output_data_config or self.raw_output_data_config, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, + trigger=trigger, overwrite_cache=overwrite_cache or self.overwrite_cache, ) @@ -435,8 +444,8 @@ def security_context(self) -> Optional[security.SecurityContext]: return self._security_context @property - def additional_metadata(self) -> Optional[Any]: - return self._additional_metadata + def trigger(self) -> Optional[LaunchPlanTriggerBase]: + return self._trigger def construct_node_metadata(self) -> _workflow_model.NodeMetadata: return self.workflow.construct_node_metadata() diff --git a/flytekit/core/node.py b/flytekit/core/node.py index f5a3db4afa..e8a37bf3f0 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -65,6 +65,7 @@ def __init__( self._outputs = None self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None + self._container_image: typing.Optional[str] = None def runs_before(self, other: Node): """ @@ -193,7 +194,7 @@ def with_overrides(self, *args, **kwargs): if "container_image" in kwargs: v = kwargs["container_image"] assert_not_promise(v, "container_image") - self.run_entity._container_image = v + self._container_image = v if "accelerator" in kwargs: v = kwargs["accelerator"] diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 93116d0720..5ce0948cfd 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,13 +6,20 @@ import datetime import re as _re -from typing import Optional +from typing import Optional, Protocol, Union import croniter as _croniter +from flyteidl.admin import schedule_pb2 +from google.protobuf import message as google_message from flytekit.models import schedule as _schedule_models +class LaunchPlanTriggerBase(Protocol): + def to_flyte_idl(self, *args, **kwargs) -> google_message.Message: + ... 
diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py
index 93116d0720..5ce0948cfd 100644
--- a/flytekit/core/schedule.py
+++ b/flytekit/core/schedule.py
@@ -6,13 +6,20 @@
 import datetime
 import re as _re
-from typing import Optional
+from typing import Optional, Protocol, Union
 
 import croniter as _croniter
+from flyteidl.admin import schedule_pb2
+from google.protobuf import message as google_message
 
 from flytekit.models import schedule as _schedule_models
 
 
+class LaunchPlanTriggerBase(Protocol):
+    def to_flyte_idl(self, *args, **kwargs) -> google_message.Message:
+        ...
+
+
 # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass.
 class CronSchedule(_schedule_models.Schedule):
     """
@@ -202,3 +209,14 @@ def _translate_duration(duration: datetime.timedelta):
             int(duration.total_seconds() / _SECONDS_TO_MINUTES),
             _schedule_models.Schedule.FixedRateUnit.MINUTE,
         )
+
+
+class OnSchedule(LaunchPlanTriggerBase):
+    def __init__(self, schedule: Union[CronSchedule, FixedRate]):
+        """
+        :param Union[CronSchedule, FixedRate] schedule: Either a cron or a fixed rate
+        """
+        self._schedule = schedule
+
+    def to_flyte_idl(self) -> schedule_pb2.Schedule:
+        return self._schedule.to_flyte_idl()
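Under the new [alpha] trigger syntax, an OnSchedule trigger would be attached to a launch plan roughly like this (a sketch; `my_wf` is an assumed existing @workflow, and note that the translator change later in this diff still rejects plain Schedule messages at serialization time):

from flytekit import LaunchPlan
from flytekit.core.schedule import CronSchedule, OnSchedule

lp = LaunchPlan.get_or_create(
    workflow=my_wf,  # assumed to exist
    name="my_wf_scheduled",
    # Run every 30 minutes; FixedRate(datetime.timedelta(...)) works too.
    trigger=OnSchedule(CronSchedule(schedule="*/30 * * * *")),
)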
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 76f750233b..9153fca032 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -16,6 +16,7 @@
 from typing import Dict, List, NamedTuple, Optional, Type, cast
 
 from dataclasses_json import DataClassJsonMixin, dataclass_json
+from flyteidl.core import literals_pb2
 from google.protobuf import json_format as _json_format
 from google.protobuf import struct_pb2 as _struct
 from google.protobuf.json_format import MessageToDict as _MessageToDict
@@ -1164,19 +1165,34 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models.
     @classmethod
     @timeit("Translate literal to python value")
     def literal_map_to_kwargs(
-        cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type]
+        cls,
+        ctx: FlyteContext,
+        lm: LiteralMap,
+        python_types: typing.Optional[typing.Dict[str, type]] = None,
+        literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None,
     ) -> typing.Dict[str, typing.Any]:
         """
         Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
         """
-        if len(lm.literals) > len(python_types):
+        if python_types is None and literal_types is None:
+            raise ValueError("At least one of python_types or literal_types must be provided")
+
+        if literal_types:
+            python_interface_inputs = {
+                name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items()
+            }
+        else:
+            python_interface_inputs = python_types  # type: ignore
+
+        if len(lm.literals) > len(python_interface_inputs):
             raise ValueError(
-                f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
+                f"Received more input values {len(lm.literals)}"
+                f" than allowed by the input spec {len(python_interface_inputs)}"
             )
         kwargs = {}
         for i, k in enumerate(lm.literals):
             try:
-                kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k])
+                kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k])
             except TypeTransformerFailedError as exc:
                 raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n  {exc}") from exc
         return kwargs
@@ -1210,6 +1226,16 @@ def dict_to_literal_map(
                 raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v)
         return LiteralMap(literal_map)
 
+    @classmethod
+    def dict_to_literal_map_pb(
+        cls,
+        ctx: FlyteContext,
+        d: typing.Dict[str, typing.Any],
+        type_hints: Optional[typing.Dict[str, type]] = None,
+    ) -> Optional[literals_pb2.LiteralMap]:
+        literal_map = cls.dict_to_literal_map(ctx, d, type_hints)
+        return literal_map.to_flyte_idl()
+
     @classmethod
     def get_available_transformers(cls) -> typing.KeysView[Type]:
         """
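A rough illustration of the two TypeEngine code paths added above (a sketch; the values are arbitrary and not part of this PR):

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.models.literals import LiteralMap

ctx = FlyteContextManager.current_context()

# dict_to_literal_map_pb: native values -> protobuf LiteralMap in one call,
# which is what the agent service now returns to propeller.
lm_pb = TypeEngine.dict_to_literal_map_pb(ctx, {"x": 3}, type_hints={"x": int})

# literal_map_to_kwargs can still take python_types; with literal_types the
# python types are guessed from each Variable's LiteralType instead.
kwargs = TypeEngine.literal_map_to_kwargs(ctx, LiteralMap.from_flyte_idl(lm_pb), python_types={"x": int})
assert kwargs == {"x": 3}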
diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py
index 2d4246c6c1..c000b92150 100644
--- a/flytekit/extend/backend/agent_service.py
+++ b/flytekit/extend/backend/agent_service.py
@@ -1,4 +1,5 @@
 import typing
+from http import HTTPStatus
 
 import grpc
 from flyteidl.admin.agent_pb2 import (
@@ -6,19 +7,28 @@
     CreateTaskResponse,
     DeleteTaskRequest,
     DeleteTaskResponse,
+    ExecuteTaskSyncRequest,
+    ExecuteTaskSyncResponse,
+    ExecuteTaskSyncResponseHeader,
     GetAgentRequest,
     GetAgentResponse,
     GetTaskRequest,
     GetTaskResponse,
     ListAgentsRequest,
     ListAgentsResponse,
+    Resource,
+)
+from flyteidl.service.agent_pb2_grpc import (
+    AgentMetadataServiceServicer,
+    AsyncAgentServiceServicer,
+    SyncAgentServiceServicer,
 )
-from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer
 from prometheus_client import Counter, Summary
 
-from flytekit import logger
+from flytekit import FlyteContext, logger
+from flytekit.core.type_engine import TypeEngine
 from flytekit.exceptions.system import FlyteAgentNotFound
-from flytekit.extend.backend.base_agent import AgentRegistry, mirror_async_methods
+from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
 from flytekit.models.literals import LiteralMap
 from flytekit.models.task import TaskTemplate
 
@@ -26,6 +36,7 @@
 create_operation = "create"
 get_operation = "get"
 delete_operation = "delete"
+do_operation = "do"
 
 # Follow the naming convention. https://prometheus.io/docs/practices/naming/
 request_success_count = Counter(
@@ -46,7 +57,24 @@
 input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])
 
 
-def agent_exception_handler(func: typing.Callable):
+def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str):
+    if isinstance(e, FlyteAgentNotFound):
+        error_message = f"Cannot find agent for task type: {task_type}."
+        logger.error(error_message)
+        context.set_code(grpc.StatusCode.NOT_FOUND)
+        context.set_details(error_message)
+        request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()
+    else:
+        error_message = f"failed to {operation} {task_type} task with error: {e}."
+        logger.error(error_message)
+        context.set_code(grpc.StatusCode.INTERNAL)
+        context.set_details(error_message)
+        request_failure_count.labels(
+            task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR
+        ).inc()
+
+
+def record_agent_metrics(func: typing.Callable):
     async def wrapper(
         self,
         request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
@@ -60,10 +88,10 @@ async def wrapper(
             if request.inputs:
                 input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
         elif isinstance(request, GetTaskRequest):
-            task_type = request.task_type
+            task_type = request.task_type or request.task_category.name
             operation = get_operation
         elif isinstance(request, DeleteTaskRequest):
-            task_type = request.task_type
+            task_type = request.task_type or request.task_category.name
             operation = delete_operation
         else:
             context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -75,51 +103,90 @@ async def wrapper(
             res = await func(self, request, context, *args, **kwargs)
             request_success_count.labels(task_type=task_type, operation=operation).inc()
             return res
-        except FlyteAgentNotFound:
-            error_message = f"Cannot find agent for task type: {task_type}."
-            logger.error(error_message)
-            context.set_code(grpc.StatusCode.NOT_FOUND)
-            context.set_details(error_message)
-            request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc()
         except Exception as e:
-            error_message = f"failed to {operation} {task_type} task with error {e}."
-            logger.error(error_message)
-            context.set_code(grpc.StatusCode.INTERNAL)
-            context.set_details(error_message)
-            request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc()
+            _handle_exception(e, context, task_type, operation)
 
     return wrapper
 
 
 class AsyncAgentService(AsyncAgentServiceServicer):
-    @agent_exception_handler
+    @record_agent_metrics
     async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
-        tmp = TaskTemplate.from_flyte_idl(request.template)
+        template = TaskTemplate.from_flyte_idl(request.template)
         inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
-        agent = AgentRegistry.get_agent(tmp.type)
+        agent = AgentRegistry.get_agent(template.type, template.task_type_version)
 
-        logger.info(f"{tmp.type} agent start creating the job")
-        return await mirror_async_methods(
-            agent.create, output_prefix=request.output_prefix, task_template=tmp, inputs=inputs
-        )
+        logger.info(f"{agent.name} start creating the job")
+        resource_meta = await mirror_async_methods(agent.create, task_template=template, inputs=inputs)
+        return CreateTaskResponse(resource_meta=resource_meta.encode())
 
-    @agent_exception_handler
+    @record_agent_metrics
     async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
-        agent = AgentRegistry.get_agent(request.task_type)
-        logger.info(f"{agent.task_type} agent start checking the status of the job")
-        return await mirror_async_methods(agent.get, resource_meta=request.resource_meta)
+        if request.task_category and request.task_category.name:
+            agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
+        else:
+            agent = AgentRegistry.get_agent(request.task_type)
+        logger.info(f"{agent.name} start checking the status of the job")
+        res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))
+
+        if res.outputs is None:
+            outputs = None
+        elif isinstance(res.outputs, LiteralMap):
+            outputs = res.outputs.to_flyte_idl()
+        else:
+            ctx = FlyteContext.current_context()
+            outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)
+        return GetTaskResponse(
+            resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
+        )
 
-    @agent_exception_handler
+    @record_agent_metrics
     async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
-        agent = AgentRegistry.get_agent(request.task_type)
-        logger.info(f"{agent.task_type} agent start deleting the job")
-        return await mirror_async_methods(agent.delete, resource_meta=request.resource_meta)
+        if request.task_category and request.task_category.name:
+            agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
+        else:
+            agent = AgentRegistry.get_agent(request.task_type)
+        logger.info(f"{agent.name} start deleting the job")
+        return await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta))
+class SyncAgentService(SyncAgentServiceServicer):
+    async def ExecuteTaskSync(
+        self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext
+    ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]:
+        request = await request_iterator.__anext__()
+        template = TaskTemplate.from_flyte_idl(request.header.template)
+        task_type = template.type
+        try:
+            with request_latency.labels(task_type=task_type, operation=do_operation).time():
+                agent = AgentRegistry.get_agent(task_type, template.task_type_version)
+                if not isinstance(agent, SyncAgentBase):
+                    raise ValueError(f"[{agent.name}] does not support sync execution")
+
+                request = await request_iterator.__anext__()
+                literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
+                res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)
+
+                if res.outputs is None:
+                    outputs = None
+                elif isinstance(res.outputs, LiteralMap):
+                    outputs = res.outputs.to_flyte_idl()
+                else:
+                    ctx = FlyteContext.current_context()
+                    outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)
+
+                header = ExecuteTaskSyncResponseHeader(
+                    resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
+                )
+                yield ExecuteTaskSyncResponse(header=header)
+                request_success_count.labels(task_type=task_type, operation=do_operation).inc()
+        except Exception as e:
+            _handle_exception(e, context, template.type, do_operation)
 
 
 class AgentMetadataService(AgentMetadataServiceServicer):
     async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse:
-        return GetAgentResponse(agent=AgentRegistry._METADATA[request.name])
+        return GetAgentResponse(agent=AgentRegistry.get_agent_metadata(request.name))
 
     async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse:
-        agents = [agent for agent in AgentRegistry._METADATA.values()]
-        return ListAgentsResponse(agents=agents)
+        return ListAgentsResponse(agents=AgentRegistry.list_agents())
diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py
index 5a6e5cd3bf..7cbf380b16 100644
--- a/flytekit/extend/backend/base_agent.py
+++ b/flytekit/extend/backend/base_agent.py
@@ -1,72 +1,161 @@
 import asyncio
-import inspect
+import json
 import signal
 import sys
 import time
 import typing
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import OrderedDict
+from dataclasses import asdict, dataclass
 from functools import partial
 from types import FrameType, coroutine
+from typing import Any, Dict, List, Optional, Union
 
-from flyteidl.admin.agent_pb2 import (
-    Agent,
-    CreateTaskResponse,
-    DeleteTaskResponse,
-    GetTaskResponse,
-)
+from flyteidl.admin.agent_pb2 import Agent
+from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory
 from flyteidl.core import literals_pb2
-from flyteidl.core.execution_pb2 import TaskExecution
-from flyteidl.core.tasks_pb2 import TaskTemplate
+from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
 from rich.progress import Progress
 
-import flytekit
 from flytekit import FlyteContext, PythonFunctionTask, logger
 from flytekit.configuration import ImageConfig, SerializationSettings
 from flytekit.core import utils
 from flytekit.core.base_task import PythonTask
-from flytekit.core.type_engine import TypeEngine
+from flytekit.core.type_engine import TypeEngine, dataclass_from_dict
 from flytekit.exceptions.system import FlyteAgentNotFound
 from flytekit.exceptions.user import FlyteUserException
+from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template
 from flytekit.models.literals import LiteralMap
+from flytekit.models.task import TaskTemplate
+
+
+class TaskCategory:
+    def __init__(self, name: str, version: int = 0):
+        self._name = name
+        self._version = version
+
+    def __hash__(self):
+        return hash((self._name, self._version))
+
+    def __eq__(self, other: "TaskCategory"):
+        return self._name == other._name and self._version == other._version
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def version(self) -> int:
+        return self._version
+
+    def __str__(self):
+        return f"{self._name}_v{self._version}"
+
+
+@dataclass
+class ResourceMeta:
+    """
+    This is the metadata for the job. For example, the id of the job.
+    """
+
+    def encode(self) -> bytes:
+        """
+        Encode the resource meta to bytes.
+        """
+        return json.dumps(asdict(self)).encode("utf-8")
+
+    @classmethod
+    def decode(cls, data: bytes) -> "ResourceMeta":
+        """
+        Decode the resource meta from bytes.
+        """
+        return dataclass_from_dict(cls, json.loads(data.decode("utf-8")))
+
+
+@dataclass
+class Resource:
+    """
+    This is the output resource of the job.
+
+    Args:
+        phase: The phase of the job.
+        message: The return message from the job.
+        log_links: The log links of the job. For example, the link to the BigQuery Console.
+        outputs: The outputs of the job. If return python native types, the agent will convert them to flyte literals.
+    """
+
+    phase: TaskExecution.Phase
+    message: Optional[str] = None
+    log_links: Optional[List[TaskLog]] = None
+    outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None
+
+
+T = typing.TypeVar("T", bound=ResourceMeta)
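To make the encode/decode contract above concrete: a custom agent subclasses ResourceMeta with whatever fields it needs, and the bytes round-trip through JSON. A sketch with a hypothetical field name:

from dataclasses import dataclass

@dataclass
class MyJobMetadata(ResourceMeta):  # ResourceMeta as defined above
    job_id: str  # illustrative field

meta = MyJobMetadata(job_id="abc-123")
data = meta.encode()                       # b'{"job_id": "abc-123"}'
assert MyJobMetadata.decode(data) == meta  # JSON round-trip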
""" - name = "Base Agent" + name = "Base Async Agent" - def __init__(self, task_type: str, **kwargs): - self._task_type = task_type + def __init__(self, metadata_type: typing.Type[T], **kwargs): + super().__init__(**kwargs) + self._metadata_type = metadata_type @property - def task_type(self) -> str: - """ - task_type is the name of the task type that this agent supports. - """ - return self._task_type - - def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + def metadata_type(self) -> ResourceMeta: + return self._metadata_type + + @abstractmethod + def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> T: """ - Return a Unique ID for the task that was created. It should return error code if the task creation failed. + Return a resource meta that can be used to get the status of the task. """ raise NotImplementedError - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + @abstractmethod + def get(self, resource_meta: T, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -74,9 +163,10 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: """ raise NotImplementedError - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + @abstractmethod + def delete(self, resource_meta: T, **kwargs): """ - Delete the task. This call should be idempotent. + Delete the task. This call should be idempotent. It should raise an error if fails to delete the task. """ raise NotImplementedError @@ -88,29 +178,42 @@ class AgentRegistry(object): The agent metadata service will look up the agent metadata based on the agent name. 
""" - _REGISTRY: typing.Dict[str, AgentBase] = {} - _METADATA: typing.Dict[str, Agent] = {} + _REGISTRY: Dict[TaskCategory, Union[AsyncAgentBase, SyncAgentBase]] = {} + _METADATA: Dict[str, Agent] = {} @staticmethod - def register(agent: AgentBase): - if agent.task_type in AgentRegistry._REGISTRY: - raise ValueError(f"Duplicate agent for task type {agent.task_type}") - AgentRegistry._REGISTRY[agent.task_type] = agent + def register(agent: Union[AsyncAgentBase, SyncAgentBase], override: bool = False): + if agent.task_category in AgentRegistry._REGISTRY and override is False: + raise ValueError(f"Duplicate agent for task type: {agent.task_category}") + AgentRegistry._REGISTRY[agent.task_category] = agent + + task_category = _TaskCategory(name=agent.task_category.name, version=agent.task_category.version) if agent.name in AgentRegistry._METADATA: - agent_metadata = AgentRegistry._METADATA[agent.name] - agent_metadata.supported_task_types.append(agent.task_type) + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + agent_metadata.supported_task_categories.append(task_category) + agent_metadata.supported_task_types.append(task_category.name) else: - agent_metadata = Agent(name=agent.name, supported_task_types=[agent.task_type]) + agent_metadata = Agent( + name=agent.name, + supported_task_types=[task_category.name], + supported_task_categories=[task_category], + is_sync=isinstance(agent, SyncAgentBase), + ) AgentRegistry._METADATA[agent.name] = agent_metadata - logger.info(f"Registering an agent for task type: {agent.task_type}, name: {agent.name}") + logger.info(f"Registering {agent.name} for task type: {agent.task_category}") @staticmethod - def get_agent(task_type: str) -> AgentBase: - if task_type not in AgentRegistry._REGISTRY: - raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.") - return AgentRegistry._REGISTRY[task_type] + def get_agent(task_type_name: str, task_type_version: int = 0) -> Union[SyncAgentBase, AsyncAgentBase]: + task_category = TaskCategory(name=task_type_name, version=task_type_version) + if task_category not in AgentRegistry._REGISTRY: + raise FlyteAgentNotFound(f"Cannot find agent for task category: {task_category}.") + return AgentRegistry._REGISTRY[task_category] + + @staticmethod + def list_agents() -> List[Agent]: + return list(AgentRegistry._METADATA.values()) @staticmethod def get_agent_metadata(name: str) -> Agent: @@ -119,96 +222,90 @@ def get_agent_metadata(name: str) -> Agent: return AgentRegistry._METADATA[name] -def mirror_async_methods(func: typing.Callable, **kwargs) -> typing.Coroutine: - if inspect.iscoroutinefunction(func): - return func(**kwargs) - args = [v for _, v in kwargs.items()] - return asyncio.get_running_loop().run_in_executor(None, func, *args) - - -def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: - """ - Convert the state from the agent to the phase in flyte. +class SyncAgentExecutorMixin: """ - state = state.lower() - # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate - if state in ["failed", "timeout", "timedout", "canceled"]: - return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: - return TaskExecution.SUCCEEDED - elif state in ["running"]: - return TaskExecution.RUNNING - raise ValueError(f"Unrecognized state: {state}") - - -def is_terminal_phase(phase: TaskExecution.Phase) -> bool: - """ - Return true if the phase is terminal. 
+    This mixin class is used to run the sync task locally, and it's only used for local execution.
+    Task should inherit from this class if the task can be run in the agent.
+
+    Synchronous tasks run quickly and can return their results instantly. For example, sending a prompt
+    to ChatGPT and getting a response, or retrieving some metadata from a backend system.
     """
-    return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED]
 
+    T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask)
+
+    def execute(self: T, **kwargs) -> LiteralMap:
+        from flytekit.tools.translator import get_serializable
+
+        ctx = FlyteContext.current_context()
+        ss = ctx.serialization_settings or SerializationSettings(ImageConfig())
+        task_template = get_serializable(OrderedDict(), ss, self).template
+
+        agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version)
+
+        resource = asyncio.run(self._do(agent, task_template, kwargs))
+        if resource.phase != TaskExecution.SUCCEEDED:
+            raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}")
 
-def get_agent_secret(secret_key: str) -> str:
-    return flytekit.current_context().secrets.get(secret_key)
+        if resource.outputs and not isinstance(resource.outputs, LiteralMap):
+            return TypeEngine.dict_to_literal_map(ctx, resource.outputs)
+        return resource.outputs
+
+    async def _do(self: T, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None) -> Resource:
+        try:
+            ctx = FlyteContext.current_context()
+            literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
+            return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)
+        except Exception as error_message:
+            raise FlyteUserException(f"Failed to run the task {self.name} with error: {error_message}")
 
 
 class AsyncAgentExecutorMixin:
     """
-    This mixin class is used to run the agent task locally, and it's only used for local execution.
+    This mixin class is used to run the async task locally, and it's only used for local execution.
     Task should inherit from this class if the task can be run in the agent.
 
-    It can handle asynchronous tasks and synchronous tasks.
+    Asynchronous tasks are tasks that take a long time to complete, such as running a query.
-    Synchronous tasks run quickly and can return their results instantly.
     Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system.
     """
 
+    T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask)
+
     _clean_up_task: coroutine = None
-    _agent: AgentBase = None
-    _entity: PythonTask = None
+    _agent: AsyncAgentBase = None
 
-    def execute(self, **kwargs) -> typing.Any:
+    def execute(self: T, **kwargs) -> LiteralMap:
         ctx = FlyteContext.current_context()
         ss = ctx.serialization_settings or SerializationSettings(ImageConfig())
         output_prefix = ctx.file_access.get_random_remote_directory()
 
         from flytekit.tools.translator import get_serializable
 
-        self._entity = typing.cast(PythonTask, self)
-        task_template = get_serializable(OrderedDict(), ss, self._entity).template
-        self._agent = AgentRegistry.get_agent(task_template.type)
-
-        res = asyncio.run(self._create(task_template, output_prefix, kwargs))
-
-        # If the task is synchronous, the agent will return the output from the resource literals.
-        if res.HasField("resource"):
-            if res.resource.phase != TaskExecution.SUCCEEDED:
-                raise FlyteUserException(f"Failed to run the task {self._entity.name}")
-            return LiteralMap.from_flyte_idl(res.resource.outputs)
+        task_template = get_serializable(OrderedDict(), ss, self).template
+        self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version)
 
-        res = asyncio.run(self._get(resource_meta=res.resource_meta))
+        resource_meta = asyncio.run(self._create(task_template, output_prefix, kwargs))
+        resource = asyncio.run(self._get(resource_meta=resource_meta))
 
-        if res.resource.phase != TaskExecution.SUCCEEDED:
-            raise FlyteUserException(f"Failed to run the task {self._entity.name}")
+        if resource.phase != TaskExecution.SUCCEEDED:
+            raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}")
 
-        # Read the literals from a remote file, if agent doesn't return the output literals.
-        if task_template.interface.outputs and len(res.resource.outputs.literals) == 0:
+        # Read the literals from a remote file if the agent doesn't return the output literals.
+        if task_template.interface.outputs and resource.outputs is None:
             local_outputs_file = ctx.file_access.get_random_local_path()
-            ctx.file_access.get_data(f"{output_prefix}/output/outputs.pb", local_outputs_file)
+            ctx.file_access.get_data(f"{output_prefix}/outputs.pb", local_outputs_file)
             output_proto = utils.load_proto_from_file(literals_pb2.LiteralMap, local_outputs_file)
             return LiteralMap.from_flyte_idl(output_proto)
 
-        return LiteralMap.from_flyte_idl(res.resource.outputs)
+        if resource.outputs and not isinstance(resource.outputs, LiteralMap):
+            return TypeEngine.dict_to_literal_map(ctx, resource.outputs)
+
+        return resource.outputs
 
     async def _create(
-        self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Dict[str, typing.Any] = None
-    ) -> CreateTaskResponse:
+        self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None
+    ) -> ResourceMeta:
         ctx = FlyteContext.current_context()
 
-        # Convert python inputs to literals
-        literals = inputs or {}
-        for k, v in inputs.items():
-            literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type)
-        literal_map = LiteralMap(literals)
-
+        literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
         if isinstance(self, PythonFunctionTask):
             # Write the inputs to a remote file, so that the remote task can read the inputs from this file.
             path = ctx.file_access.get_random_local_path()
@@ -216,58 +313,47 @@ async def _create(
             utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
             ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
             task_template = render_task_template(task_template, output_prefix)
 
-        res = await mirror_async_methods(
+        resource_meta = await mirror_async_methods(
             self._agent.create,
-            output_prefix=output_prefix,
             task_template=task_template,
             inputs=literal_map,
         )
 
-        signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta))  # type: ignore
-        return res
+        signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta))  # type: ignore
+        return resource_meta
 
-    async def _get(self, resource_meta: bytes) -> GetTaskResponse:
+    async def _get(self: T, resource_meta: ResourceMeta) -> Resource:
         phase = TaskExecution.RUNNING
 
         progress = Progress(transient=True)
-        task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None)
+        task = progress.add_task(f"[cyan]Running Task {self.name}...", total=None)
         task_phase = progress.add_task("[cyan]Task phase: RUNNING, Phase message: ", total=None, visible=False)
         task_log_links = progress.add_task("[cyan]Log Links: ", total=None, visible=False)
         with progress:
             while not is_terminal_phase(phase):
                 progress.start_task(task)
                 time.sleep(1)
-                res = await mirror_async_methods(self._agent.get, resource_meta=resource_meta)
+                resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta)
                 if self._clean_up_task:
                     await self._clean_up_task
                     sys.exit(1)
 
-                phase = res.resource.phase
+                phase = resource.phase
                 progress.update(
                     task_phase,
-                    description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {res.resource.message}",
+                    description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {resource.message}",
                     visible=True,
                 )
-                log_links = ""
-                for link in res.log_links:
-                    log_links += f"{link.name}: {link.uri}\n"
-                if log_links:
-                    progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True)
+                if resource.log_links:
+                    log_links = ""
+                    for link in resource.log_links:
+                        log_links += f"{link.name}: {link.uri}\n"
+                    if log_links:
+                        progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True)
 
-        return res
+        return resource
 
-    def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any:
+    def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameType) -> Any:
         if self._clean_up_task is None:
             co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta)
             self._clean_up_task = asyncio.create_task(co)
-
-
-def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate:
-    args = tt.container.args
-    for i in range(len(args)):
-        tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb")
-        tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}/output")
-        tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output")
-        tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output")
-        tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint")
-    return tt
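Putting the new base classes together, a toy sync agent could look like the following. This is a sketch built on the APIs introduced in this diff; the "echo" task type and its behavior are invented for illustration:

from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class EchoAgent(SyncAgentBase):
    name = "Echo Agent"

    def __init__(self):
        super().__init__(task_type_name="echo")  # hypothetical task type

    def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource:
        # A sync agent returns its result in the same call; native outputs
        # are converted to literals by the agent service.
        return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": "echo"})


AgentRegistry.register(EchoAgent())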
diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py
new file mode 100644
index 0000000000..b20c9fdf66
--- /dev/null
+++ b/flytekit/extend/backend/utils.py
@@ -0,0 +1,52 @@
+import asyncio
+import inspect
+from typing import Callable, Coroutine
+
+from flyteidl.core.execution_pb2 import TaskExecution
+
+import flytekit
+from flytekit.models.task import TaskTemplate
+
+
+def mirror_async_methods(func: Callable, **kwargs) -> Coroutine:
+    if inspect.iscoroutinefunction(func):
+        return func(**kwargs)
+    args = [v for _, v in kwargs.items()]
+    return asyncio.get_running_loop().run_in_executor(None, func, *args)
+
+
+def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
+    """
+    Convert the state from the agent to the phase in flyte.
+    """
+    state = state.lower()
+    # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
+    if state in ["failed", "timeout", "timedout", "canceled"]:
+        return TaskExecution.FAILED
+    elif state in ["done", "succeeded", "success"]:
+        return TaskExecution.SUCCEEDED
+    elif state in ["running"]:
+        return TaskExecution.RUNNING
+    raise ValueError(f"Unrecognized state: {state}")
+
+
+def is_terminal_phase(phase: TaskExecution.Phase) -> bool:
+    """
+    Return true if the phase is terminal.
+    """
+    return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED]
+
+
+def get_agent_secret(secret_key: str) -> str:
+    return flytekit.current_context().secrets.get(secret_key)
+
+
+def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate:
+    args = tt.container.args
+    for i in range(len(args)):
+        tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb")
+        tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}")
+        tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output")
+        tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output")
+        tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint")
+    return tt
diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py
index 62636d1420..44fe7e1f44 100644
--- a/flytekit/models/core/workflow.py
+++ b/flytekit/models/core/workflow.py
@@ -595,10 +595,14 @@ def from_flyte_idl(cls, pb2_object):
 
 class TaskNodeOverrides(_common.FlyteIdlEntity):
     def __init__(
-        self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
+        self,
+        resources: typing.Optional[Resources],
+        extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
+        container_image: typing.Optional[str] = None,
    ):
         self._resources = resources
         self._extended_resources = extended_resources
+        self._container_image = container_image
 
     @property
     def resources(self) -> Resources:
@@ -608,19 +612,25 @@ def resources(self) -> Resources:
     def extended_resources(self) -> tasks_pb2.ExtendedResources:
         return self._extended_resources
 
+    @property
+    def container_image(self) -> typing.Optional[str]:
+        return self._container_image
+
     def to_flyte_idl(self):
         return _core_workflow.TaskNodeOverrides(
             resources=self.resources.to_flyte_idl() if self.resources is not None else None,
             extended_resources=self.extended_resources,
+            container_image=self.container_image,
         )
 
     @classmethod
     def from_flyte_idl(cls, pb2_object):
         resources = Resources.from_flyte_idl(pb2_object.resources)
         extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
+        container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None
         if bool(resources.requests) or bool(resources.limits):
-            return cls(resources=resources, extended_resources=extended_resources)
-        return cls(resources=None, extended_resources=extended_resources)
+            return cls(resources=resources, extended_resources=extended_resources, container_image=container_image)
+        return cls(resources=None, extended_resources=extended_resources, container_image=container_image)
 
 
 class TaskNode(_common.FlyteIdlEntity):
diff --git a/flytekit/models/schedule.py b/flytekit/models/schedule.py
index a6be2a58ee..65d3f477ac 100644
--- a/flytekit/models/schedule.py
+++ b/flytekit/models/schedule.py
@@ -1,13 +1,13 @@
-from flyteidl.admin import schedule_pb2 as _schedule_pb2
+from flyteidl.admin import schedule_pb2
 
-from flytekit.models import common as _common
+from flytekit.models import common
 
 
-class Schedule(_common.FlyteIdlEntity):
+class Schedule(common.FlyteIdlEntity):
     class FixedRateUnit(object):
-        MINUTE = _schedule_pb2.MINUTE
-        HOUR = _schedule_pb2.HOUR
-        DAY = _schedule_pb2.DAY
+        MINUTE = schedule_pb2.MINUTE
+        HOUR = schedule_pb2.HOUR
+        DAY = schedule_pb2.DAY
 
         @classmethod
         def enum_to_string(cls, int_value):
@@ -24,7 +24,7 @@ def enum_to_string(cls, int_value):
             else:
                 return "{}".format(int_value)
 
-    class FixedRate(_common.FlyteIdlEntity):
+    class FixedRate(common.FlyteIdlEntity):
         def __init__(self, value, unit):
             """
             :param int value:
@@ -51,7 +51,7 @@ def to_flyte_idl(self):
             """
             :rtype: flyteidl.admin.schedule_pb2.FixedRate
             """
-            return _schedule_pb2.FixedRate(value=self.value, unit=self.unit)
+            return schedule_pb2.FixedRate(value=self.value, unit=self.unit)
 
         @classmethod
         def from_flyte_idl(cls, pb2_object):
@@ -61,7 +61,7 @@ def from_flyte_idl(cls, pb2_object):
             """
             return cls(pb2_object.value, pb2_object.unit)
 
-    class CronSchedule(_common.FlyteIdlEntity):
+    class CronSchedule(common.FlyteIdlEntity):
         def __init__(self, schedule, offset):
             """
             :param Text schedule: cron expression or aliases
@@ -88,7 +88,7 @@ def to_flyte_idl(self):
             """
             :rtype: flyteidl.admin.schedule_pb2.FixedRate
             """
-            return _schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)
+            return schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)
 
         @classmethod
         def from_flyte_idl(cls, pb2_object):
@@ -145,7 +145,7 @@ def to_flyte_idl(self):
         """
         :rtype: flyteidl.admin.schedule_pb2.Schedule
         """
-        return _schedule_pb2.Schedule(
+        return schedule_pb2.Schedule(
             kickoff_time_input_arg=self.kickoff_time_input_arg,
             cron_expression=self.cron_expression,
             rate=self.rate.to_flyte_idl() if self.rate is not None else None,
diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py
index 0e40055ea5..3392f77009 100644
--- a/flytekit/sensor/base_sensor.py
+++ b/flytekit/sensor/base_sensor.py
@@ -1,26 +1,48 @@
 import collections
 import inspect
+import typing
 from abc import abstractmethod
+from dataclasses import asdict, dataclass
 from typing import Any, Dict, Optional, TypeVar
 
-import jsonpickle
-from typing_extensions import get_type_hints
+from typing_extensions import Protocol, get_type_hints, runtime_checkable
 
 from flytekit.configuration import SerializationSettings
 from flytekit.core.base_task import PythonTask
 from flytekit.core.interface import Interface
-from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
+from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin, ResourceMeta
 
-T = TypeVar("T")
-SENSOR_MODULE = "sensor_module"
-SENSOR_NAME = "sensor_name"
-SENSOR_CONFIG_PKL = "sensor_config_pkl"
-INPUTS = "inputs"
+
+@runtime_checkable
+class SensorConfig(Protocol):
+    def to_dict(self) -> typing.Dict[str, Any]:
+        """
+        Serialize the sensor config to a dictionary.
+        """
+ """ + raise NotImplementedError + + @classmethod + def from_dict(cls, d: typing.Dict[str, Any]) -> "SensorConfig": + """ + Deserialize the sensor config from a dictionary. + """ + raise NotImplementedError + + +@dataclass +class SensorMetadata(ResourceMeta): + sensor_module: str + sensor_name: str + sensor_config: Optional[dict] = None + inputs: Optional[dict] = None + + +T = TypeVar("T", bound=SensorConfig) class BaseSensor(AsyncAgentExecutorMixin, PythonTask): """ - Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some + Base class for all sensors. Sensors are tasks that are designed to run forever and periodically check for some condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the sensor agent, and not by the Flyte engine. """ @@ -57,10 +79,9 @@ async def poke(self, **kwargs) -> bool: raise NotImplementedError def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - cfg = { - SENSOR_MODULE: type(self).__module__, - SENSOR_NAME: type(self).__name__, - } - if self._sensor_config is not None: - cfg[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config) - return cfg + sensor_config = self._sensor_config.to_dict() if self._sensor_config else None + return asdict( + SensorMetadata( + sensor_module=type(self).__module__, sensor_name=type(self).__name__, sensor_config=sensor_config + ) + ) diff --git a/flytekit/sensor/file_sensor.py b/flytekit/sensor/file_sensor.py index 2fb3d64ec1..f894546927 100644 --- a/flytekit/sensor/file_sensor.py +++ b/flytekit/sensor/file_sensor.py @@ -1,14 +1,10 @@ -from typing import Optional, TypeVar - from flytekit import FlyteContextManager from flytekit.sensor.base_sensor import BaseSensor -T = TypeVar("T") - class FileSensor(BaseSensor): - def __init__(self, name: str, config: Optional[T] = None, **kwargs): - super().__init__(name=name, sensor_config=config, **kwargs) + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) async def poke(self, path: str) -> bool: file_access = FlyteContextManager.current_context().file_access diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 816360715a..ac718abe35 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -1,62 +1,49 @@ import importlib -import typing from typing import Optional -import cloudpickle -import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SensorMetadata -T = typing.TypeVar("T") - -class SensorEngine(AgentBase): +class SensorEngine(AsyncAgentBase): name = "Sensor" def __init__(self): - super().__init__(task_type="sensor") - - async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - python_interface_inputs = { - name: TypeEngine.guess_python_type(lt.type) for name, lt in 
task_template.interface.inputs.items()
-        }
-        ctx = FlyteContextManager.current_context()
+        super().__init__(task_type_name="sensor", metadata_type=SensorMetadata)
+
+    async def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> SensorMetadata:
+        sensor_metadata = SensorMetadata(**task_template.custom)
+        if inputs:
+            ctx = FlyteContextManager.current_context()
+            python_interface_inputs = {
+                name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
+            }
             native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
-            task_template.custom[INPUTS] = native_inputs
-        return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom))
+            sensor_metadata.inputs = native_inputs

-    async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
-        meta = cloudpickle.loads(resource_meta)
+        return sensor_metadata

-        sensor_module = importlib.import_module(name=meta[SENSOR_MODULE])
-        sensor_def = getattr(sensor_module, meta[SENSOR_NAME])
-        sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None
+    async def get(self, resource_meta: SensorMetadata, **kwargs) -> Resource:
+        sensor_module = importlib.import_module(name=resource_meta.sensor_module)
+        sensor_def = getattr(sensor_module, resource_meta.sensor_name)

-        inputs = meta.get(INPUTS, {})
+        inputs = resource_meta.inputs

         cur_phase = (
             TaskExecution.SUCCEEDED
-            if await sensor_def("sensor", config=sensor_config).poke(**inputs)
+            if await sensor_def("sensor", config=resource_meta.sensor_config).poke(**inputs)
             else TaskExecution.RUNNING
         )
-        return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=None))
+        return Resource(phase=cur_phase, outputs=None)

-    async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
-        return DeleteTaskResponse()
+    async def delete(self, resource_meta: SensorMetadata, **kwargs):
+        return


 AgentRegistry.register(SensorEngine())
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index 6d696bc4d6..2847ff1b3d 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -4,6 +4,8 @@
 from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple, Union

+from flyteidl.admin import schedule_pb2
+
 from flytekit import PythonFunctionTask, SourceCode
 from flytekit.configuration import SerializationSettings
 from flytekit.core import constants as _common_constants
@@ -368,12 +370,19 @@ def get_serializable_launch_plan(
     else:
         raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("")

+    if entity.trigger:
+        lc = entity.trigger.to_flyte_idl(entity)
+        if isinstance(lc, schedule_pb2.Schedule):
+            raise ValueError("Please continue to use the schedule arg; the trigger arg is not implemented yet")
+    else:
+        lc = None
+
     lps = _launch_plan_models.LaunchPlanSpec(
         workflow_id=wf_id,
         entity_metadata=_launch_plan_models.LaunchPlanMetadata(
             schedule=entity.schedule,
             notifications=options.notifications or entity.notifications,
-            launch_conditions=entity.additional_metadata,
+            launch_conditions=lc,
         ),
         default_inputs=entity.parameters,
         fixed_inputs=entity.fixed_inputs,
@@ -468,7 +477,11 @@ def get_serializable_node(
             output_aliases=[],
             task_node=workflow_model.TaskNode(
                 reference_id=task_spec.template.id,
-                overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
+                overrides=TaskNodeOverrides(
+                    resources=entity._resources,
+ extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) if entity._aliases: @@ -545,7 +558,11 @@ def get_serializable_node( output_aliases=[], task_node=workflow_model.TaskNode( reference_id=entity.flyte_entity.id, - overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), + overrides=TaskNodeOverrides( + resources=entity._resources, + extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): @@ -594,7 +611,11 @@ def get_serializable_array_node( task_spec = get_serializable(entity_mapping, settings, entity, options) task_node = workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources), + overrides=TaskNodeOverrides( + resources=node._resources, + extended_resources=node._extended_resources, + container_image=node._container_image, + ), ) node = workflow_model.Node( id=entity.name, diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 06b0c44a87..da6bc4d699 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -158,7 +158,7 @@ def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: return cls(path=remote_path) def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: - from . import FileExt + from flytekit.types.file import FileExt if item is None: return cls diff --git a/plugins/flytekit-airflow/dev-requirements.txt b/plugins/flytekit-airflow/dev-requirements.txt index 114279f520..a3d41be209 100644 --- a/plugins/flytekit-airflow/dev-requirements.txt +++ b/plugins/flytekit-airflow/dev-requirements.txt @@ -636,7 +636,7 @@ opentelemetry-semantic-conventions==0.42b0 # via opentelemetry-sdk ordered-set==4.1.0 # via flask-limiter -orjson==3.9.10 +orjson==3.9.15 # via apache-beam overrides==6.5.0 # via google-cloud-pubsublite diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py index e52453d7bb..2ff0d0e9a5 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py @@ -5,12 +5,6 @@ import cloudpickle import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance @@ -21,13 +15,13 @@ from airflow.utils.context import Context from flytekit import logger from flytekit.exceptions.user import FlyteUserException -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class ResourceMetadata: +class AirflowMetadata(ResourceMeta): """ This class is used to store the Airflow task configuration. It is serialized and returned to FlytePropeller. 
""" @@ -37,8 +31,15 @@ class ResourceMetadata: airflow_trigger_callback: str = field(default=None) job_id: typing.Optional[str] = field(default=None) + def encode(self) -> bytes: + return cloudpickle.dumps(self) -class AirflowAgent(AgentBase): + @classmethod + def decode(cls, data: bytes) -> "AirflowMetadata": + return cloudpickle.loads(data) + + +class AirflowAgent(AsyncAgentBase): """ It is used to run Airflow tasks. It is registered as an agent in the AgentRegistry. There are three kinds of Airflow tasks: AirflowOperator, AirflowSensor, and AirflowHook. @@ -62,22 +63,18 @@ class AirflowAgent(AgentBase): name = "Airflow Agent" def __init__(self): - super().__init__(task_type="airflow") + super().__init__(task_type_name="airflow", metadata_type=AirflowMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> AirflowMetadata: airflow_obj = jsonpickle.decode(task_template.custom["task_config_pkl"]) airflow_instance = _get_airflow_instance(airflow_obj) - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) if isinstance(airflow_instance, BaseOperator) and not isinstance(airflow_instance, BaseSensorOperator): try: - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) airflow_instance.execute(context=Context()) except TaskDeferred as td: parameters = td.trigger.__dict__.copy() @@ -90,12 +87,13 @@ async def create( ) resource_meta.airflow_trigger_callback = td.method_name - return CreateTaskResponse(resource_meta=cloudpickle.dumps(resource_meta)) + return resource_meta - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - meta = cloudpickle.loads(resource_meta) - airflow_operator_instance = _get_airflow_instance(meta.airflow_operator) - airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None + async def get(self, resource_meta: AirflowMetadata, **kwargs) -> Resource: + airflow_operator_instance = _get_airflow_instance(resource_meta.airflow_operator) + airflow_trigger_instance = ( + _get_airflow_instance(resource_meta.airflow_trigger) if resource_meta.airflow_trigger else None + ) airflow_ctx = Context() message = None cur_phase = TaskExecution.RUNNING @@ -107,7 +105,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: if airflow_trigger_instance: try: # Airflow trigger returns immediately when - # 1. Failed to get the task status + # 1. Failed to get task status # 2. Task succeeded or failed # succeeded or failed: returns a TriggerEvent with payload # running: runs forever, so set a default timeout (2 seconds) here. @@ -115,7 +113,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: event = await asyncio.wait_for(airflow_trigger_instance.run().__anext__(), 2) try: # Trigger callback will check the status of the task in the payload, and raise AirflowException if failed. 
- trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback) + trigger_callback = getattr(airflow_operator_instance, resource_meta.airflow_trigger_callback) trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload) cur_phase = TaskExecution.SUCCEEDED except AirflowException as e: @@ -136,10 +134,10 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: else: raise FlyteUserException("Only sensor and operator are supported.") - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message)) + return Resource(phase=cur_phase, message=message) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + async def delete(self, resource_meta: AirflowMetadata, **kwargs): + return AgentRegistry.register(AirflowAgent()) diff --git a/plugins/flytekit-airflow/setup.py b/plugins/flytekit-airflow/setup.py index 682cd72c18..09536d2e90 100644 --- a/plugins/flytekit-airflow/setup.py +++ b/plugins/flytekit-airflow/setup.py @@ -6,8 +6,8 @@ plugin_requires = [ "apache-airflow", - "flytekit>=1.9.0", - "flyteidl>=1.10.6", + "flytekit>1.10.7", + "flyteidl>1.10.7", ] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index dc4d167b10..57999d5c59 100644 --- a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -5,10 +5,9 @@ from airflow.operators.python import PythonOperator from airflow.sensors.bash import BashSensor from airflow.sensors.time_sensor import TimeSensor -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow import AirflowObj -from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata +from flytekitplugins.airflow.agent import AirflowAgent, AirflowMetadata from flytekit import workflow from flytekit.interfaces.cli_identifiers import Identifier @@ -44,7 +43,7 @@ def test_resource_metadata(): parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, ) trigger_cfg = AirflowObj(module="airflow.trigger.file", name="FileTrigger", parameters={"filepath": "file.txt"}) - meta = ResourceMetadata( + meta = AirflowMetadata( airflow_operator=task_cfg, airflow_trigger=trigger_cfg, airflow_trigger_callback="execute_complete", @@ -89,10 +88,9 @@ async def test_airflow_agent(): ) agent = AirflowAgent() - res = await agent.create("/tmp", dummy_template, None) - metadata = res.resource_meta - res = await agent.get(metadata) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.message == "" + metadata = await agent.create(dummy_template, None) + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.message is None res = await agent.delete(metadata) - assert res == DeleteTaskResponse() + assert res is None diff --git a/plugins/flytekit-aws-batch/setup.py b/plugins/flytekit-aws-batch/setup.py index db75ce18b9..423e439ba2 100644 --- a/plugins/flytekit-aws-batch/setup.py +++ b/plugins/flytekit-aws-batch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.3.0b2"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 4c34285793..0275162f72 100644 --- 
a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -1,25 +1,16 @@ import datetime -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Dict, Optional -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase -from flytekit.models import literals -from flytekit.models.core.execution import TaskLog +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.models.types import LiteralType, StructuredDatasetType pythonTypeToBigQueryType: Dict[type, str] = { # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes @@ -34,25 +25,24 @@ @dataclass -class Metadata: +class BigQueryMetadata(ResourceMeta): job_id: str project: str location: str -class BigQueryAgent(AgentBase): +class BigQueryAgent(AsyncAgentBase[BigQueryMetadata]): name = "Bigquery Agent" def __init__(self): - super().__init__(task_type="bigquery_query_job_task") + super().__init__(task_type_name="bigquery_query_job_task", metadata_type=BigQueryMetadata) def create( self, - output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs, - ) -> CreateTaskResponse: + ) -> BigQueryMetadata: job_config = None if inputs: ctx = FlyteContextManager.current_context() @@ -73,54 +63,36 @@ def create( location = custom["Location"] client = bigquery.Client(project=project, location=location) query_job = client.query(task_template.sql.statement, job_config=job_config) - metadata = Metadata(job_id=str(query_job.job_id), location=location, project=project) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - log_links = [ - TaskLog( - uri=f"https://console.cloud.google.com/bigquery?project={metadata.project}&j=bq:{metadata.location}:{metadata.job_id}&page=queryresults", - name="BigQuery Console", - ).to_flyte_idl() - ] - - job = client.get_job(metadata.job_id, metadata.project, metadata.location) + log_link = TaskLog( + uri=f"https://console.cloud.google.com/bigquery?project={resource_meta.project}&j=bq:{resource_meta.location}:{resource_meta.job_id}&page=queryresults", + name="BigQuery Console", + ) + + job = client.get_job(resource_meta.job_id, resource_meta.project, resource_meta.location) if job.errors: logger.error("failed to run BigQuery job with error:", job.errors.__str__()) - return GetTaskResponse( - resource=Resource(state=TaskExecution.FAILED, message=job.errors.__str__()), log_links=log_links - ) + return Resource(phase=TaskExecution.FAILED, 
message=job.errors.__str__(), log_links=[log_link]) cur_phase = convert_to_flyte_phase(str(job.state)) res = None if cur_phase == TaskExecution.SUCCEEDED: - ctx = FlyteContextManager.current_context() - if job.destination: - output_location = ( - f"bq://{job.destination.project}:{job.destination.dataset_id}.{job.destination.table_id}" - ) - res = literals.LiteralMap( - { - "results": TypeEngine.to_literal( - ctx, - StructuredDataset(uri=output_location), - StructuredDataset, - LiteralType(structured_dataset_type=StructuredDatasetType(format="")), - ) - } - ).to_flyte_idl() - - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links) - - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + dst = job.destination + if dst: + ctx = FlyteContextManager.current_context() + output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" + res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)}) + + return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res) + + def delete(self, resource_meta: BigQueryMetadata, **kwargs): client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - client.cancel_job(metadata.job_id, metadata.project, metadata.location) - return DeleteTaskResponse() + client.cancel_job(resource_meta.job_id, resource_meta.project, resource_meta.location) AgentRegistry.register(BigQueryAgent()) diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index 10dd3c7ca5..9f2dea65c0 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "google-cloud-bigquery", "flyteidl>=v1.10.6"] +plugin_requires = ["flytekit>1.10.7", "google-cloud-bigquery", "flyteidl>1.10.7"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index dc2af4ab80..5897b4b468 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -1,10 +1,8 @@ -import json -from dataclasses import asdict from datetime import timedelta from unittest import mock from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.bigquery.agent import Metadata +from flytekitplugins.bigquery.agent import BigQueryMetadata import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry @@ -86,20 +84,18 @@ def __init__(self): sql=Sql("SELECT 1"), ) - metadata_bytes = json.dumps( - asdict(Metadata(job_id="dummy_id", project="dummy_project", location="us-central1")) - ).encode("utf-8") - assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - res = agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + metadata = BigQueryMetadata(job_id="dummy_id", project="dummy_project", location="us-central1") + assert agent.create(dummy_template, task_inputs) == metadata + resource = agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" ) - assert res.log_links[0].name == "BigQuery Console" + assert resource.log_links[0].name == "BigQuery Console" 
assert ( - res.log_links[0].uri + resource.log_links[0].uri == "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults" ) - agent.delete(metadata_bytes) + agent.delete(metadata) mock_instance.cancel_job.assert_called() diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index 0ef3fcf2fc..506dd4853b 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.5.0,<2.0.0", + "flytekit>=1.5.0", "great-expectations>=0.13.30,<=0.18.8", "sqlalchemy>=1.4.23,<2.0.0", "pyspark==3.3.1", diff --git a/plugins/flytekit-identity-aware-proxy/setup.py b/plugins/flytekit-identity-aware-proxy/setup.py index fa5d4ef4a0..c279dfe263 100644 --- a/plugins/flytekit-identity-aware-proxy/setup.py +++ b/plugins/flytekit-identity-aware-proxy/setup.py @@ -11,7 +11,7 @@ "flytekit>=1.10", # https://github.com/grpc/grpc/issues/33935 # https://github.com/grpc/grpc/issues/35323 - "grpcio<1.55.0", + "grpcio>=1.62.0", ] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-k8s-pod/setup.py b/plugins/flytekit-k8s-pod/setup.py index 9767c24ddb..1a3479805b 100644 --- a/plugins/flytekit-k8s-pod/setup.py +++ b/plugins/flytekit-k8s-pod/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.3.0b2,<2.0.0", + "flytekit>=1.3.0b2", "kubernetes>=12.0.1", ] diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py index 285be4e88b..e0dbceada2 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py @@ -1,30 +1,29 @@ import json import shlex import subprocess -from dataclasses import asdict, dataclass +from dataclasses import dataclass from tempfile import NamedTemporaryFile from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_phase from flytekit import current_context -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.loggers import logger from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class Metadata: +class MMCloudMetadata(ResourceMeta): job_id: str -class MMCloudAgent(AgentBase): +class MMCloudAgent(AsyncAgentBase): name = "MMCloud Agent" def __init__(self): - super().__init__(task_type="mmcloud_task", asynchronous=True) + super().__init__(task_type_name="mmcloud_task", metadata_type=MMCloudMetadata) self._response_format = ["--format", "json"] async def async_login(self): @@ -57,10 +56,10 @@ async def async_login(self): logger.info("Logged in to OpCenter") async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> MMCloudMetadata: """ - Submit Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. + Submit a Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. 
""" submit_command = [ "float", @@ -128,16 +127,13 @@ async def create( logger.exception("Cannot open job script for writing") raise - metadata = Metadata(job_id=job_id) + return MMCloudMetadata(job_id=job_id) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + async def get(self, resource_meta: MMCloudMetadata, **kwargs) -> Resource: """ Return the status of the task, and return the outputs on success. """ - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id show_command = [ "float", @@ -173,14 +169,13 @@ async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}") logger.debug(f"OpCenter response: {show_response}") - return GetTaskResponse(resource=Resource(phase=task_phase)) + return Resource(phase=task_phase) - async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + async def delete(self, resource_meta: MMCloudMetadata, **kwargs): """ Delete the task. This call should be idempotent. """ - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id cancel_command = [ "float", @@ -203,7 +198,5 @@ async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskRespon logger.info(f"Submitted cancel request for MMCloud job: {job_id}") - return DeleteTaskResponse() - AgentRegistry.register(MMCloudAgent()) diff --git a/plugins/flytekit-mmcloud/tests/test_mmcloud.py b/plugins/flytekit-mmcloud/tests/test_mmcloud.py index eff4c4e63c..79830e2c56 100644 --- a/plugins/flytekit-mmcloud/tests/test_mmcloud.py +++ b/plugins/flytekit-mmcloud/tests/test_mmcloud.py @@ -115,7 +115,7 @@ def say_hello0(name: str) -> str: assert isinstance(agent, MMCloudAgent) create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -124,13 +124,13 @@ def say_hello0(name: str) -> str: ) resource_meta = create_task_response.resource_meta - get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + get_task_response = asyncio.run(agent.get(context=context, resource_meta=resource_meta)) phase = get_task_response.resource.phase assert phase in (TaskExecution.RUNNING, TaskExecution.SUCCEEDED) - asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + asyncio.run(agent.delete(context=context, resource_meta=resource_meta)) - get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + get_task_response = asyncio.run(agent.get(context=context, resource_meta=resource_meta)) phase = get_task_response.resource.phase assert phase == TaskExecution.FAILED @@ -146,7 +146,7 @@ def say_hello1(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello1) with pytest.raises(subprocess.CalledProcessError): create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -165,7 +165,7 @@ def say_hello2(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello2) with pytest.raises(subprocess.CalledProcessError): create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", 
task_template=task_spec.template,
@@ -183,7 +183,7 @@ def say_hello3(name: str) -> str:
     task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello3)
     create_task_response = asyncio.run(
-        agent.async_create(
+        agent.create(
             context=context,
             output_prefix="",
             task_template=task_spec.template,
@@ -191,7 +191,7 @@ def say_hello3(name: str) -> str:
         )
     )
     resource_meta = create_task_response.resource_meta
-    asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta))
+    asyncio.run(agent.delete(context=context, resource_meta=resource_meta))

     @task(
         task_config=MMCloudConfig(),
@@ -203,7 +203,7 @@ def say_hello4(name: str) -> str:
     task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello4)
     create_task_response = asyncio.run(
-        agent.async_create(
+        agent.create(
             context=context,
             output_prefix="",
             task_template=task_spec.template,
@@ -211,7 +211,7 @@ def say_hello4(name: str) -> str:
         )
     )
     resource_meta = create_task_response.resource_meta
-    asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta))
+    asyncio.run(agent.delete(context=context, resource_meta=resource_meta))

     @task(
         task_config=MMCloudConfig(),
@@ -222,7 +222,7 @@ def say_hello5(name: str) -> str:
     task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello5)
     create_task_response = asyncio.run(
-        agent.async_create(
+        agent.create(
             context=context,
             output_prefix="",
             task_template=task_spec.template,
@@ -230,4 +230,4 @@ def say_hello5(name: str) -> str:
         )
     )
     resource_meta = create_task_response.resource_meta
-    asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta))
+    asyncio.run(agent.delete(context=context, resource_meta=resource_meta))
diff --git a/plugins/flytekit-modin/setup.py b/plugins/flytekit-modin/setup.py
index 7d62ea16fe..0a3394a2d0 100644
--- a/plugins/flytekit-modin/setup.py
+++ b/plugins/flytekit-modin/setup.py
@@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

 plugin_requires = [
-    "flytekit<1.3.0b2,<2.0.0",
+    "flytekit",
     "modin[ray]>=0.13.0",
     "fsspec",
 ]
diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py
index 45780ae174..fe55536066 100644
--- a/plugins/flytekit-onnx-scikitlearn/setup.py
+++ b/plugins/flytekit-onnx-scikitlearn/setup.py
@@ -4,7 +4,7 @@

 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

-plugin_requires = ["flytekit<1.3.0b2,<2.0.0", "skl2onnx>=1.10.3", "networkx<3.2; python_version<'3.9'"]
+plugin_requires = ["flytekit", "skl2onnx>=1.10.3", "networkx<3.2; python_version<'3.9'"]

 __version__ = "0.0.0+develop"
diff --git a/plugins/flytekit-openai/README.md b/plugins/flytekit-openai/README.md
new file mode 100644
index 0000000000..21c5553ce7
--- /dev/null
+++ b/plugins/flytekit-openai/README.md
@@ -0,0 +1,44 @@
+# Flytekit ChatGPT Plugin
+The ChatGPT plugin allows you to run ChatGPT tasks in a Flyte workflow without changing any code.
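> Editor's sketch (not part of this diff's README): each `ChatGPTTask` serializes its `openai_organization` and `chatgpt_config` into the task template's `custom` field, and the agent authenticates with an API key read from the agent's secret store. A minimal sketch of that lookup, using only the `get_agent_secret` helper and the `FLYTE_OPENAI_API_KEY` secret name introduced elsewhere in this diff:

```python
# Sketch: how the ChatGPT agent resolves its OpenAI credentials.
# `get_agent_secret` wraps flytekit.current_context().secrets.get(...),
# so the secret must be available wherever the agent process runs.
from flytekit.extend.backend.utils import get_agent_secret

api_key = get_agent_secret(secret_key="FLYTE_OPENAI_API_KEY")
# The agent then builds an openai.AsyncOpenAI(organization=..., api_key=api_key)
# client and sends the task's chatgpt_config as the chat-completion request.
```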
+
+## Example
+```python
+from flytekit import task, workflow
+from flytekitplugins.chatgpt import ChatGPTTask
+
+chatgpt_small_job = ChatGPTTask(
+    name="chatgpt gpt-3.5-turbo",
+    openai_organization="org-NayNG68kGnVXMJ8Ak4PMgQv7",
+    chatgpt_config={
+        "model": "gpt-3.5-turbo",
+        "temperature": 0.7,
+    },
+)
+
+chatgpt_big_job = ChatGPTTask(
+    name="chatgpt gpt-4",
+    openai_organization="org-NayNG68kGnVXMJ8Ak4PMgQv7",
+    chatgpt_config={
+        "model": "gpt-4",
+        "temperature": 0.7,
+    },
+)
+
+
+@workflow
+def wf(message: str) -> str:
+    message = chatgpt_small_job(message=message)
+    message = chatgpt_big_job(message=message)
+    return message
+
+
+if __name__ == "__main__":
+    print(wf(message="hi"))
+```
+
+
+To install the plugin, run the following command:
+
+```bash
+pip install flytekitplugins-chatgpt
+```
diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py
new file mode 100644
index 0000000000..64dd73fb35
--- /dev/null
+++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py
@@ -0,0 +1,12 @@
+"""
+.. currentmodule:: flytekitplugins.chatgpt
+This package contains things that are useful when extending Flytekit.
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+   ChatGPTAgent
+   ChatGPTTask
+"""
+
+from .agent import ChatGPTAgent
+from .task import ChatGPTTask
diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py
new file mode 100644
index 0000000000..afd3af1321
--- /dev/null
+++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py
@@ -0,0 +1,52 @@
+import asyncio
+import logging
+from typing import Optional
+
+from flyteidl.core.execution_pb2 import TaskExecution
+
+from flytekit import FlyteContextManager, lazy_module
+from flytekit.core.type_engine import TypeEngine
+from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
+from flytekit.extend.backend.utils import get_agent_secret
+from flytekit.models.literals import LiteralMap
+from flytekit.models.task import TaskTemplate
+
+openai = lazy_module("openai")
+
+TIMEOUT_SECONDS = 10
+OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY"
+
+
+class ChatGPTAgent(SyncAgentBase):
+    name = "ChatGPT Agent"
+
+    def __init__(self):
+        super().__init__(task_type_name="chatgpt")
+
+    async def do(
+        self,
+        task_template: TaskTemplate,
+        inputs: Optional[LiteralMap] = None,
+    ) -> Resource:
+        ctx = FlyteContextManager.current_context()
+        input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str})
+        message = input_python_value["message"]
+
+        custom = task_template.custom
+        custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}]
+        client = openai.AsyncOpenAI(
+            organization=custom["openai_organization"],
+            api_key=get_agent_secret(secret_key=OPENAI_API_KEY),
+        )
+
+        logger = logging.getLogger("httpx")
+        logger.setLevel(logging.WARNING)
+
+        completion = await asyncio.wait_for(client.chat.completions.create(**custom["chatgpt_config"]), TIMEOUT_SECONDS)
+        message = completion.choices[0].message.content
+        outputs = {"o0": message}
+
+        return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs)
+
+
+AgentRegistry.register(ChatGPTAgent())
diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py
new file mode 100644
index 0000000000..c37a40650d
--- /dev/null
+++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py
@@ -0,0
+1,44 @@
+from typing import Any, Dict
+
+from flytekit.configuration import SerializationSettings
+from flytekit.core.base_task import PythonTask
+from flytekit.core.interface import Interface
+from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin
+
+
+class ChatGPTTask(SyncAgentExecutorMixin, PythonTask):
+    """
+    This is the simplest form of a ChatGPT task; you define the model to use and the input you want to send it.
+    """
+
+    _TASK_TYPE = "chatgpt"
+
+    def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], **kwargs):
+        """
+        Args:
+            name: Name of this task; it should be unique within the project.
+            openai_organization: OpenAI organization ID. See https://platform.openai.com/docs/api-reference/organization-optional
+            chatgpt_config: ChatGPT job configuration. See https://platform.openai.com/docs/api-reference/completions/create
+        """
+
+        if "model" not in chatgpt_config:
+            raise ValueError("The 'model' configuration variable is required in chatgpt_config")
+
+        task_config = {"openai_organization": openai_organization, "chatgpt_config": chatgpt_config}
+
+        inputs = {"message": str}
+        outputs = {"o0": str}
+
+        super().__init__(
+            task_type=self._TASK_TYPE,
+            name=name,
+            task_config=task_config,
+            interface=Interface(inputs=inputs, outputs=outputs),
+            **kwargs,
+        )
+
+    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
+        return {
+            "openai_organization": self.task_config["openai_organization"],
+            "chatgpt_config": self.task_config["chatgpt_config"],
+        }
diff --git a/plugins/flytekit-openai/setup.py b/plugins/flytekit-openai/setup.py
new file mode 100644
index 0000000000..82257c2435
--- /dev/null
+++ b/plugins/flytekit-openai/setup.py
@@ -0,0 +1,38 @@
+from setuptools import setup
+
+PLUGIN_NAME = "chatgpt"
+
+microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
+
+plugin_requires = ["flytekit>1.10.7", "openai>=1.12.0", "flyteidl>=1.11.0b0"]
+
+__version__ = "0.0.0+develop"
+
+setup(
+    name=microlib_name,
+    version=__version__,
+    author="flyteorg",
+    author_email="admin@flyte.org",
+    description="This package holds the ChatGPT plugins for flytekit",
+    namespace_packages=["flytekitplugins"],
+    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
+    install_requires=plugin_requires,
+    license="apache2",
+    python_requires=">=3.8",
+    classifiers=[
+        "Intended Audience :: Science/Research",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
+)
diff --git a/plugins/flytekit-openai/tests/test_agent.py b/plugins/flytekit-openai/tests/test_agent.py
new file mode 100644
index 0000000000..dd340bd1a7
--- /dev/null
+++ b/plugins/flytekit-openai/tests/test_agent.py
@@ -0,0 +1,69 @@
+from datetime import timedelta
+from unittest import mock
+
+import pytest
+from flyteidl.core.execution_pb2 import TaskExecution
+
+from flytekit.extend.backend.base_agent import AgentRegistry
+from flytekit.interfaces.cli_identifiers import
Identifier +from flytekit.models import literals +from flytekit.models.core.identifier import ResourceType +from flytekit.models.literals import LiteralMap +from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate + + +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + +@pytest.mark.asyncio +async def test_chatgpt_agent(): + agent = AgentRegistry.get_agent("chatgpt") + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_config = { + "openai_organization": "test-openai-orgnization-id", + "chatgpt_config": {"model": "gpt-3.5-turbo", "temperature": 0.7}, + } + task_metadata = TaskMetadata( + True, + RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + tmp = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=None, + type="chatgpt", + ) + + task_inputs = LiteralMap( + { + "message": literals.Literal( + scalar=literals.Scalar(primitive=literals.Primitive(string_value="Test ChatGPT Plugin")) + ), + }, + ) + message = "mocked_message" + mocked_token = "mocked_openai_api_key" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + with mock.patch("openai.resources.chat.completions.AsyncCompletions.create", new=mock_acreate): + # Directly await the coroutine without using asyncio.run + response = await agent.do(tmp, task_inputs) + + assert response.phase == TaskExecution.SUCCEEDED + assert response.outputs == {"o0": message} diff --git a/plugins/flytekit-openai/tests/test_chatgpt.py b/plugins/flytekit-openai/tests/test_chatgpt.py new file mode 100644 index 0000000000..f85f94cc7b --- /dev/null +++ b/plugins/flytekit-openai/tests/test_chatgpt.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +from flytekitplugins.chatgpt import ChatGPTTask + +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable +from flytekit.models.types import SimpleType + + +def test_chatgpt_task(): + chatgpt_task = ChatGPTTask( + name="chatgpt", + openai_organization="TEST ORGANIZATION ID", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, + ) + + assert len(chatgpt_task.interface.inputs) == 1 + assert len(chatgpt_task.interface.outputs) == 1 + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + chatgpt_task_spec = get_serializable(OrderedDict(), serialization_settings, chatgpt_task) + custom = chatgpt_task_spec.template.custom + assert custom["openai_organization"] == "TEST ORGANIZATION ID" + assert custom["chatgpt_config"]["model"] == "gpt-3.5-turbo" + assert custom["chatgpt_config"]["temperature"] == 0.7 + + assert len(chatgpt_task_spec.template.interface.inputs) == 1 + assert len(chatgpt_task_spec.template.interface.outputs) == 1 + + assert chatgpt_task_spec.template.interface.inputs["message"].type.simple == SimpleType.STRING + assert 
chatgpt_task_spec.template.interface.outputs["o0"].type.simple == SimpleType.STRING diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index d0a9617bdb..3dc10d1afc 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,4 +1,4 @@ -flyteidl>=1.10.7b0 +-e file:../../.#egg=flytekit -e file:../../.#egg=flytekitplugins-pod&subdirectory=plugins/flytekit-k8s-pod -e file:../../.#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark -e file:../../.#egg=flytekitplugins-awsbatch&subdirectory=plugins/flytekit-aws-batch diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index d06bc68085..8cb38662e3 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -1,18 +1,12 @@ -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -25,7 +19,7 @@ @dataclass -class Metadata: +class SnowflakeJobMetadata(ResourceMeta): user: str account: str database: str @@ -53,7 +47,7 @@ def get_private_key(): return pkb -def get_connection(metadata: Metadata) -> snowflake_connector: +def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector: return snowflake_connector.connect( user=metadata.user, account=metadata.account, @@ -64,25 +58,18 @@ def get_connection(metadata: Metadata) -> snowflake_connector: ) -class SnowflakeAgent(AgentBase): +class SnowflakeAgent(AsyncAgentBase): def __init__(self): - super().__init__(task_type=TASK_TYPE) + super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata) async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - params = None - if inputs: - ctx = FlyteContextManager.current_context() - python_interface_inputs = { - name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() - } - native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - logger.info(f"Create Snowflake agent params with inputs: {native_inputs}") - params = native_inputs + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> SnowflakeJobMetadata: + ctx = FlyteContextManager.current_context() + literal_types = task_template.interface.inputs + params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs else None config = task_template.config - conn = snowflake_connector.connect( user=config["user"], account=config["account"], @@ -95,7 +82,7 @@ async def create( cs = conn.cursor() 
cs.execute_async(task_template.sql.statement, params=params) - metadata = Metadata( + return SnowflakeJobMetadata( user=config["user"], account=config["account"], database=config["database"], @@ -105,22 +92,19 @@ async def create( query_id=str(cs.sfqid), ) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource: + conn = get_connection(resource_meta) try: - query_status = conn.get_query_status_throw_if_error(metadata.query_id) + query_status = conn.get_query_status_throw_if_error(resource_meta.query_id) except snowflake_connector.ProgrammingError as err: logger.error("Failed to get snowflake job status with error:", err.msg) - return GetTaskResponse(resource=Resource(state=TaskExecution.FAILED)) + return Resource(phase=TaskExecution.FAILED) cur_phase = convert_to_flyte_phase(str(query_status.name)) res = None if cur_phase == TaskExecution.SUCCEEDED: ctx = FlyteContextManager.current_context() - output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}" + output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}" res = literals.LiteralMap( { "results": TypeEngine.to_literal( @@ -132,19 +116,17 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res)) + return Resource(phase=cur_phase, outputs=res) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs): + conn = get_connection(resource_meta) cs = conn.cursor() try: - cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{resource_meta.query_id}')") cs.fetchall() finally: cs.close() conn.close() - return DeleteTaskResponse() AgentRegistry.register(SnowflakeAgent()) diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index 527daa2486..b5265c299e 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"] +plugin_requires = ["flytekit>1.10.7", "snowflake-connector-python>=3.1.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 017297704e..f3dcb0686d 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -1,13 +1,10 @@ -import json -from dataclasses import asdict from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import pytest -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.snowflake.agent import Metadata +from flytekitplugins.snowflake.agent import SnowflakeJobMetadata import flytekit.models.interface as 
interface_models from flytekit import lazy_module @@ -30,8 +27,11 @@ async def test_snowflake_agent(mock_get_private_key): mock_conn_instance = snowflake_connector.connect.return_value mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock - agent = AgentRegistry.get_agent("snowflake") + mock_cursor = MagicMock() + mock_cursor.sfqid = "dummy_id" + mock_conn_instance.cursor.return_value = mock_cursor + agent = AgentRegistry.get_agent("snowflake") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" ) @@ -82,32 +82,28 @@ async def test_snowflake_agent(mock_get_private_key): sql=Sql("SELECT 1"), ) - metadata = Metadata( + snowflake_metadata = SnowflakeJobMetadata( user="dummy_user", account="dummy_account", table="dummy_table", database="dummy_database", schema="dummy_schema", warehouse="dummy_warehouse", - query_id="dummy_query_id", + query_id="dummy_id", ) - res = await agent.create("/tmp", dummy_template, task_inputs) - metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id - metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") - assert res.resource_meta == metadata_bytes + metadata = await agent.create(dummy_template, task_inputs) + assert metadata == snowflake_metadata - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" ) - delete_response = await agent.delete(metadata_bytes) - - # Assert the response - assert isinstance(delete_response, DeleteTaskResponse) + delete_response = await agent.delete(snowflake_metadata) + assert delete_response is None # Verify that the expected methods were called on the mock cursor mock_cursor = mock_conn_instance.cursor.return_value diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 2fe442182a..8200263ac3 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,15 +1,14 @@ import http import json -import pickle import typing from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase, get_agent_secret +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -20,24 +19,20 @@ @dataclass -class Metadata: +class DatabricksJobMetadata(ResourceMeta): databricks_instance: str run_id: str -class DatabricksAgent(AgentBase): +class DatabricksAgent(AsyncAgentBase): name = "Databricks Agent" def __init__(self): - super().__init__(task_type="spark", asynchronous=True) + super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) 
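Editor's note: compared to the removed `pickle.dumps`/`pickle.loads` round trip, the agent framework now serializes this dataclass for FlytePropeller and hands it back, already typed, to `get` and `delete`. A rough sketch of that flow (not part of this diff), using the same placeholder instance and run IDs as the updated tests; the commented calls are coroutines and would need an event loop plus network access:

```python
from flytekitplugins.spark.agent import DatabricksAgent, DatabricksJobMetadata

# What `create` returns once the Databricks run has been submitted; the
# framework encodes it between polls and hands it back, already
# deserialized, to `get` and `delete`.
meta = DatabricksJobMetadata(
    databricks_instance="test-account.cloud.databricks.com",  # placeholder
    run_id="123",                                             # placeholder
)

# agent = DatabricksAgent()
# resource = await agent.get(meta)   # GET .../runs/get?run_id=123, maps the run state to a phase
# await agent.delete(meta)           # POST .../runs/cancel with {"run_id": "123"}
```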
async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> DatabricksJobMetadata: custom = task_template.custom container = task_template.container databricks_job = custom["databricksConf"] @@ -72,21 +67,18 @@ async def create( if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to create databricks job with error: {response}") - metadata = Metadata( - databricks_instance=databricks_instance, - run_id=str(response["run_id"]), - ) - return CreateTaskResponse(resource_meta=pickle.dumps(metadata)) + return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = pickle.loads(resource_meta) - databricks_instance = metadata.databricks_instance - databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={metadata.run_id}" + async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: + databricks_instance = resource_meta.databricks_instance + databricks_url = ( + f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" + ) async with aiohttp.ClientSession() as session: async with session.get(databricks_url, headers=get_header()) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() cur_phase = TaskExecution.RUNNING @@ -99,25 +91,21 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: message = state["state_message"] job_id = response.get("job_id") - databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" + databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}" log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links) + return Resource(phase=cur_phase, message=message, log_links=log_links) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = pickle.loads(resource_meta) - - databricks_url = f"https://{metadata.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" - data = json.dumps({"run_id": metadata.run_id}) + async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): + databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" + data = json.dumps({"run_id": resource_meta.run_id}) async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") await resp.json() - return DeleteTaskResponse() - def get_header() -> typing.Dict[str, str]: token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index ac7b650ecb..4bc8983289 100644 --- a/plugins/flytekit-spark/setup.py +++ 
diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py
index ac7b650ecb..4bc8983289 100644
--- a/plugins/flytekit-spark/setup.py
+++ b/plugins/flytekit-spark/setup.py
@@ -4,7 +4,7 @@

 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

-plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.10.0", "pandas"]
+plugin_requires = ["flytekit>1.10.7", "pyspark>=3.0.0", "aiohttp", "flyteidl>1.10.7", "pandas"]

 __version__ = "0.0.0+develop"
diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py
index 5d19b3402f..80f91c5c76 100644
--- a/plugins/flytekit-spark/tests/test_agent.py
+++ b/plugins/flytekit-spark/tests/test_agent.py
@@ -1,12 +1,11 @@
 import http
-import pickle
 from datetime import timedelta
 from unittest import mock

 import pytest
 from aioresponses import aioresponses
 from flyteidl.core.execution_pb2 import TaskExecution
-from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, Metadata, get_header
+from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata, get_header

 from flytekit.extend.backend.base_agent import AgentRegistry
 from flytekit.interfaces.cli_identifiers import Identifier
@@ -103,11 +102,9 @@ async def test_databricks_agent():
     mocked_context = mock.patch("flytekit.current_context", autospec=True).start()
     mocked_context.return_value.secrets.get.return_value = mocked_token

-    metadata_bytes = pickle.dumps(
-        Metadata(
-            databricks_instance="test-account.cloud.databricks.com",
-            run_id="123",
-        )
+    databricks_metadata = DatabricksJobMetadata(
+        databricks_instance="test-account.cloud.databricks.com",
+        run_id="123",
     )

     mock_create_response = {"run_id": "123"}
@@ -118,19 +115,19 @@ async def test_databricks_agent():
     delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel"
     with aioresponses() as mocked:
         mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response)
-        res = await agent.create("/tmp", dummy_template, None)
-        assert res.resource_meta == metadata_bytes
+        res = await agent.create(dummy_template, None)
+        assert res == databricks_metadata

         mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response)
-        res = await agent.get(metadata_bytes)
-        assert res.resource.phase == TaskExecution.SUCCEEDED
-        assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()
-        assert res.resource.message == "OK"
-        assert res.log_links[0].name == "Databricks Console"
-        assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123"
+        resource = await agent.get(databricks_metadata)
+        assert resource.phase == TaskExecution.SUCCEEDED
+        assert resource.outputs is None
+        assert resource.message == "OK"
+        assert resource.log_links[0].name == "Databricks Console"
+        assert resource.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123"

         mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response)
-        await agent.delete(metadata_bytes)
+        await agent.delete(databricks_metadata)

     assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"}
diff --git a/pyproject.toml b/pyproject.toml
index 1f83e2b4e3..07d75cf00d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,8 @@ dependencies = [
     "diskcache>=5.2.1",
     "docker>=4.0.0,<7.0.0",
     "docstring-parser>=0.9.0",
-    "flyteidl>=1.10.7",
+    "flyteidl>1.10.7",
+    "flyteidl>=1.11.0b0",
     "fsspec>=2023.3.0",
     "gcsfs>=2023.3.0",
     "googleapis-common-protos>=1.57",
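
# Aside on the two flyteidl pins above: pip applies both specifiers, and the
# explicit pre-release clause (>=1.11.0b0) opts this requirement into
# pre-releases. An illustrative check with the `packaging` library (not part
# of the PR):
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">1.10.7,>=1.11.0b0")
assert spec.contains(Version("1.11.0b0"), prereleases=True)
assert not spec.contains(Version("1.10.7"))
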
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index f92489e9c4..a1c137dc48 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -26,9 +26,10 @@

 @pytest.fixture(scope="session")
 def register():
-    subprocess.run(
+    out = subprocess.run(
         [
             "pyflyte",
+            "--verbose",
             "-c",
             CONFIG,
             "register",
@@ -43,6 +44,7 @@ def register():
             MODULE_PATH,
         ]
     )
+    assert out.returncode == 0


 def test_fetch_execute_launch_plan(register):
@@ -52,7 +54,7 @@ def test_fetch_execute_launch_plan(register):
     assert execution.outputs["o0"] == "hello world"


-def fetch_execute_launch_plan_with_args(register):
+def test_fetch_execute_launch_plan_with_args(register):
     remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
     flyte_launch_plan = remote.fetch_launch_plan(name="basic.basic_workflow.my_wf", version=VERSION)
     execution = remote.execute(flyte_launch_plan, inputs={"a": 10, "b": "foobar"}, wait=True)
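
# The register fixture in the test_remote.py hunk above now fails the session
# when registration exits non-zero. An equivalent sketch using subprocess's
# built-in check, which raises CalledProcessError instead of the explicit
# assert (CONFIG and MODULE_PATH are stand-ins for the test's constants):
import subprocess

CONFIG = "/path/to/config.yaml"
MODULE_PATH = "workflows/"
subprocess.run(["pyflyte", "--verbose", "-c", CONFIG, "register", MODULE_PATH], check=True)
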
diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py
index 40bb864c4f..4bcecde6a7 100644
--- a/tests/flytekit/unit/core/test_array_node_map_task.py
+++ b/tests/flytekit/unit/core/test_array_node_map_task.py
@@ -305,4 +305,4 @@ def my_mappable_task(a: int) -> typing.Optional[str]:
     def wf(x: typing.List[int]):
         array_node_map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

-    assert wf.nodes[0].run_entity.container_image == "random:image"
+    assert wf.nodes[0]._container_image == "random:image"
diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py
index abd9e456c2..b3fc5b5f64 100644
--- a/tests/flytekit/unit/core/test_artifacts.py
+++ b/tests/flytekit/unit/core/test_artifacts.py
@@ -194,7 +194,7 @@ def test_query_basic():
         partition_keys=["region"],
     )
     data_query = aa.query(time_partition=Inputs.dt, region=Inputs.blah)
-    assert data_query.bindings == []
+    assert data_query.binding is None
     assert data_query.artifact is aa
     dq_idl = data_query.to_flyte_idl()
     assert dq_idl.HasField("artifact_id")
@@ -271,6 +271,28 @@ def wf2(a: CustomReturn = wf_artifact):
     assert aq.artifact_id.partitions.value["region"].static_value == "LAX"


+def test_query_basic_query_bindings():
+    # Note these bindings don't really work yet.
+    aa = Artifact(
+        name="ride_count_data",
+        time_partitioned=True,
+        partition_keys=["region"],
+    )
+    bb = Artifact(
+        name="driver_data",
+        time_partitioned=True,
+        partition_keys=["region"],
+    )
+    cc = Artifact(
+        name="passenger_data",
+        time_partitioned=True,
+        partition_keys=["region"],
+    )
+    aa.query(time_partition=Inputs.dt, region=bb.partitions.region)
+    with pytest.raises(ValueError):
+        aa.query(time_partition=cc.time_partition, region=bb.partitions.region)
+
+
 def test_partition_none():
     # confirm that we can distinguish between partitions being set to empty, and not being set
     # though this is not currently used.
diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py
index c87d4c6b1f..2ae716d4b7 100644
--- a/tests/flytekit/unit/core/test_map_task.py
+++ b/tests/flytekit/unit/core/test_map_task.py
@@ -352,7 +352,7 @@ def my_mappable_task(a: int) -> typing.Optional[str]:
     def wf(x: typing.List[int]):
         map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

-    assert wf.nodes[0].flyte_entity.run_task.container_image == "random:image"
+    assert wf.nodes[0]._container_image == "random:image"


 def test_bounded_inputs_vars_order(serialization_settings):
diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py
index df16ddd244..56eb82aa1d 100644
--- a/tests/flytekit/unit/core/test_node_creation.py
+++ b/tests/flytekit/unit/core/test_node_creation.py
@@ -465,7 +465,7 @@ def wf() -> str:
         bar().with_overrides(container_image="hello/world")
         return "hi"

-    assert wf.nodes[0].flyte_entity.container_image == "hello/world"
+    assert wf.nodes[0]._container_image == "hello/world"


 def test_override_accelerator():
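
# The container_image assertion changes above (array node map task, map task,
# node creation) track one refactor: with_overrides() now records the image
# override on the node itself. A compact, self-contained sketch of the plain
# task pattern under test:
from flytekit import task, workflow


@task
def bar() -> str:
    return "hello"


@workflow
def wf() -> str:
    return bar().with_overrides(container_image="hello/world")


assert wf.nodes[0]._container_image == "hello/world"
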
diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py
index 3bdfa15114..85c88def45 100644
--- a/tests/flytekit/unit/extend/test_agent.py
+++ b/tests/flytekit/unit/extend/test_agent.py
@@ -1,111 +1,108 @@
-import asyncio
-import json
 import typing
 from collections import OrderedDict
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 from unittest.mock import MagicMock, patch

 import grpc
 import pytest
 from flyteidl.admin.agent_pb2 import (
+    CreateRequestHeader,
     CreateTaskRequest,
-    CreateTaskResponse,
     DeleteTaskRequest,
-    DeleteTaskResponse,
+    ExecuteTaskSyncRequest,
+    GetAgentRequest,
     GetTaskRequest,
-    GetTaskResponse,
     ListAgentsRequest,
     ListAgentsResponse,
-    Resource,
+    TaskCategory,
 )
-from flyteidl.core.execution_pb2 import TaskExecution
+from flyteidl.core.execution_pb2 import TaskExecution, TaskLog

 from flytekit import PythonFunctionTask, task
 from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
-from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
+from flytekit.core.base_task import PythonTask, kwtypes
+from flytekit.core.interface import Interface
+from flytekit.exceptions.system import FlyteAgentNotFound
+from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService
 from flytekit.extend.backend.base_agent import (
-    AgentBase,
     AgentRegistry,
+    AsyncAgentBase,
     AsyncAgentExecutorMixin,
-    convert_to_flyte_phase,
-    get_agent_secret,
+    Resource,
+    ResourceMeta,
+    SyncAgentBase,
+    SyncAgentExecutorMixin,
     is_terminal_phase,
     render_task_template,
 )
+from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
 from flytekit.models import literals
-from flytekit.models.core.execution import TaskLog
 from flytekit.models.literals import LiteralMap
 from flytekit.models.task import TaskTemplate
 from flytekit.tools.translator import get_serializable

 dummy_id = "dummy_id"
-loop = asyncio.get_event_loop()


 @dataclass
-class Metadata:
+class DummyMetadata(ResourceMeta):
     job_id: str


-class DummyAgent(AgentBase):
+class DummyAgent(AsyncAgentBase):
     name = "Dummy Agent"

     def __init__(self):
-        super().__init__(task_type="dummy", asynchronous=False)
+        super().__init__(task_type_name="dummy", metadata_type=DummyMetadata)

-    def create(
-        self, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs
-    ) -> CreateTaskResponse:
-        return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8"))
+    def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap], **kwargs) -> DummyMetadata:
+        return DummyMetadata(job_id=dummy_id)

-    def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
-        return GetTaskResponse(
-            resource=Resource(phase=TaskExecution.SUCCEEDED),
-            log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()],
-        )
+    def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
+        return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])

-    def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
-        return DeleteTaskResponse()
+    def delete(self, resource_meta: DummyMetadata, **kwargs):
+        ...


-class AsyncDummyAgent(AgentBase):
+class AsyncDummyAgent(AsyncAgentBase):
     name = "Async Dummy Agent"

     def __init__(self):
-        super().__init__(task_type="async_dummy")
+        super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata)

     async def create(
-        self,
-        output_prefix: str,
-        task_template: TaskTemplate,
-        inputs: typing.Optional[LiteralMap] = None,
-        **kwargs,
-    ) -> CreateTaskResponse:
-        return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8"))
+        self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs
+    ) -> DummyMetadata:
+        return DummyMetadata(job_id=dummy_id)
+
+    async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
+        return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
+
+    async def delete(self, resource_meta: DummyMetadata, **kwargs):
+        ...

-    async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
-        return GetTaskResponse(resource=Resource(phase=TaskExecution.SUCCEEDED))

-    async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
-        return DeleteTaskResponse()
+class MockOpenAIAgent(SyncAgentBase):
+    name = "mock openAI Agent"

+    def __init__(self):
+        super().__init__(task_type_name="openai")
+
+    def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource:
+        assert inputs.literals["a"].scalar.primitive.integer == 1
+        return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1})

-class SyncDummyAgent(AgentBase):
-    name = "Sync Dummy Agent"
+
+class MockAsyncOpenAIAgent(SyncAgentBase):
+    name = "mock async openAI Agent"

     def __init__(self):
-        super().__init__(task_type="sync_dummy")
+        super().__init__(task_type_name="async_openai")

-    def create(
-        self,
-        output_prefix: str,
-        task_template: TaskTemplate,
-        inputs: typing.Optional[LiteralMap] = None,
-        **kwargs,
-    ) -> CreateTaskResponse:
-        return CreateTaskResponse(
-            resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl())
-        )
+    async def do(self, task_template: TaskTemplate, inputs: LiteralMap = None, **kwargs) -> Resource:
+        assert inputs.literals["a"].scalar.primitive.integer == 1
+        return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1})


 def get_task_template(task_type: str) -> TaskTemplate:
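
# The typed DummyMetadata above replaces the hand-rolled json/pickle blobs. A
# sketch of the round-trip the service-level tests below rely on, assuming
# ResourceMeta supplies default encode()/decode() for its dataclass subclasses
# (as the flytekit base class imported here does):
from dataclasses import dataclass

from flytekit.extend.backend.base_agent import ResourceMeta


@dataclass
class DummyMetadata(ResourceMeta):
    job_id: str


meta = DummyMetadata(job_id="dummy_id")
assert DummyMetadata.decode(meta.encode()) == meta
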
@@ -134,113 +131,149 @@ def simple_task(i: int):
 )


-dummy_template = get_task_template("dummy")
-async_dummy_template = get_task_template("async_dummy")
-sync_dummy_template = get_task_template("sync_dummy")
-
-
 def test_dummy_agent():
-    AgentRegistry.register(DummyAgent())
+    AgentRegistry.register(DummyAgent(), override=True)
     agent = AgentRegistry.get_agent("dummy")
-    metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
-    assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
-    res = agent.get(metadata_bytes)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    assert agent.delete(metadata_bytes) == DeleteTaskResponse()
+    template = get_task_template("dummy")
+    metadata = DummyMetadata(job_id=dummy_id)
+    assert agent.create(template, task_inputs) == DummyMetadata(job_id=dummy_id)
+    resource = agent.get(metadata)
+    assert resource.phase == TaskExecution.SUCCEEDED
+    assert resource.log_links[0].name == "console"
+    assert resource.log_links[0].uri == "localhost:3000"
+    assert agent.delete(metadata) is None

     class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):
         def __init__(self, **kwargs):
-            super().__init__(
-                task_type="dummy",
-                **kwargs,
-            )
+            super().__init__(task_type="dummy", **kwargs)

     t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy")
     t.execute()

     t._task_type = "non-exist-type"
-    with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."):
+    with pytest.raises(Exception, match="Cannot find agent for task category: non-exist-type."):
         t.execute()

-    agent_metadata = AgentRegistry.get_agent_metadata("Dummy Agent")
-    assert agent_metadata.name == "Dummy Agent"
-    assert agent_metadata.supported_task_types == ["dummy"]
-

+@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"])
 @pytest.mark.asyncio
-async def test_async_dummy_agent():
-    AgentRegistry.register(AsyncDummyAgent())
-    agent = AgentRegistry.get_agent("async_dummy")
-    metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
-    res = await agent.create("/tmp", async_dummy_template, task_inputs)
+async def test_async_agent_service(agent):
+    AgentRegistry.register(agent, override=True)
+    service = AsyncAgentService()
+    ctx = MagicMock(spec=grpc.ServicerContext)
+
+    inputs_proto = task_inputs.to_flyte_idl()
+    output_prefix = "/tmp"
+    metadata_bytes = DummyMetadata(job_id=dummy_id).encode()
+
+    tmp = get_task_template(agent.task_category.name).to_flyte_idl()
+    task_category = TaskCategory(name=agent.task_category.name, version=0)
+    req = CreateTaskRequest(inputs=inputs_proto, output_prefix=output_prefix, template=tmp)
+
+    res = await service.CreateTask(req, ctx)
     assert res.resource_meta == metadata_bytes
-    res = await agent.get(metadata_bytes)
+    res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
     assert res.resource.phase == TaskExecution.SUCCEEDED
-    res = await agent.delete(metadata_bytes)
-    assert res == DeleteTaskResponse()
+    res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
+    assert res is None

-    agent_metadata = AgentRegistry.get_agent_metadata("Async Dummy Agent")
-    assert agent_metadata.name == "Async Dummy Agent"
-    assert agent_metadata.supported_task_types == ["async_dummy"]
+    agent_metadata = AgentRegistry.get_agent_metadata(agent.name)
+    assert agent_metadata.supported_task_types[0] == agent.task_category.name
+    assert agent_metadata.supported_task_categories[0].name == agent.task_category.name

+    with pytest.raises(FlyteAgentNotFound):
+        AgentRegistry.get_agent_metadata("non-exist-name")

-@pytest.mark.asyncio
-async def test_sync_dummy_agent():
-    AgentRegistry.register(SyncDummyAgent())
-    agent = AgentRegistry.get_agent("sync_dummy")
-    res = agent.create("/tmp", sync_dummy_template, task_inputs)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

-    agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent")
-    assert agent_metadata.name == "Sync Dummy Agent"
-    assert agent_metadata.supported_task_types == ["sync_dummy"]
+def test_register_agent():
+    agent = DummyAgent()
+    AgentRegistry.register(agent, override=True)
+    assert AgentRegistry.get_agent("dummy").name == agent.name
+
+    with pytest.raises(ValueError, match="Duplicate agent for task type: dummy_v0"):
+        AgentRegistry.register(agent)
+
+    with pytest.raises(FlyteAgentNotFound):
+        AgentRegistry.get_agent("non-exist-type")
+
+    agents = AgentRegistry.list_agents()
+    assert len(agents) >= 1


 @pytest.mark.asyncio
-async def run_agent_server():
-    service = AsyncAgentService()
+async def test_agent_metadata_service():
+    agent = DummyAgent()
+    AgentRegistry.register(agent, override=True)
+
     ctx = MagicMock(spec=grpc.ServicerContext)
-    request = CreateTaskRequest(
-        inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl()
-    )
-    async_request = CreateTaskRequest(
-        inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl()
-    )
-    sync_request = CreateTaskRequest(
-        inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl()
-    )
-    fake_agent = "fake"
-    metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
+    metadata_service = AgentMetadataService()
+    res = await metadata_service.ListAgents(ListAgentsRequest(), ctx)
+    assert isinstance(res, ListAgentsResponse)
+    res = await metadata_service.GetAgent(GetAgentRequest(name="Dummy Agent"), ctx)
+    assert res.agent.name == agent.name
+    assert res.agent.supported_task_types[0] == agent.task_category.name
+    assert res.agent.supported_task_categories[0].name == agent.task_category.name

-    res = await service.CreateTask(request, ctx)
-    assert res.resource_meta == metadata_bytes
-    res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
-    assert isinstance(res, DeleteTaskResponse)

-    res = await service.CreateTask(async_request, ctx)
-    assert res.resource_meta == metadata_bytes
-    res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
-    assert isinstance(res, DeleteTaskResponse)
+def test_openai_agent():
+    AgentRegistry.register(MockOpenAIAgent(), override=True)

-    res = await service.CreateTask(sync_request, ctx)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    assert res.resource.outputs == LiteralMap({}).to_flyte_idl()
+    class OpenAITask(SyncAgentExecutorMixin, PythonTask):
+        def __init__(self, **kwargs):
+            super().__init__(
+                task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs
+            )

-    res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
-    assert res is None
+    t = OpenAITask(task_config={}, name="openai task")
+    res = t(a=1)
+    assert res == 1
+
+
+def test_async_openai_agent():
+    AgentRegistry.register(MockAsyncOpenAIAgent(), override=True)
+
+    class OpenAITask(SyncAgentExecutorMixin, PythonTask):
+        def __init__(self, **kwargs):
+            super().__init__(
+                task_type="async_openai",
+                interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)),
+                **kwargs,
+            )
+
+    t = OpenAITask(task_config={}, name="openai task")
+    res = t(a=1)
+    assert res == 1
+
+
+async def get_request_iterator(task_type: str):
+    inputs_proto = task_inputs.to_flyte_idl()
+    template = get_task_template(task_type).to_flyte_idl()
+    header = CreateRequestHeader(template=template, output_prefix="/tmp")
+    yield ExecuteTaskSyncRequest(header=header)
+    yield ExecuteTaskSyncRequest(inputs=inputs_proto)

-    metadata_service = AgentMetadataService()
-    res = await metadata_service.ListAgent(ListAgentsRequest(), ctx)
-    assert isinstance(res, ListAgentsResponse)

+@pytest.mark.asyncio
+async def test_sync_agent_service():
+    AgentRegistry.register(MockOpenAIAgent(), override=True)
+    ctx = MagicMock(spec=grpc.ServicerContext)
+
+    service = SyncAgentService()
+    res = await service.ExecuteTaskSync(get_request_iterator("openai"), ctx).__anext__()
+    assert res.header.resource.phase == TaskExecution.SUCCEEDED
+    assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1
+
+
+@pytest.mark.asyncio
+async def test_sync_agent_service_with_asyncio():
+    AgentRegistry.register(MockAsyncOpenAIAgent(), override=True)
+    AgentRegistry.register(DummyAgent(), override=True)
+    ctx = MagicMock(spec=grpc.ServicerContext)

-def test_agent_server():
-    loop.run_in_executor(None, run_agent_server)
+    service = SyncAgentService()
+    res = await service.ExecuteTaskSync(get_request_iterator("async_openai"), ctx).__anext__()
+    assert res.header.resource.phase == TaskExecution.SUCCEEDED
+    assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1


 def test_is_terminal_phase():
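
# The sync path above is a request stream: a header message (task template +
# output prefix) followed by an inputs message, with the resource read off the
# first response. A hedged sketch of a consumer, reusing this test module's
# get_request_iterator helper defined in the hunk above:
from unittest.mock import MagicMock

import grpc

from flytekit.extend.backend.agent_service import SyncAgentService


async def run_sync_task(task_type: str):
    service = SyncAgentService()
    ctx = MagicMock(spec=grpc.ServicerContext)
    stream = service.ExecuteTaskSync(get_request_iterator(task_type), ctx)
    first = await stream.__anext__()  # first response carries phase and outputs
    return first.header.resource
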
@@ -274,7 +307,8 @@ def test_get_agent_secret(mocked_context):


 def test_render_task_template():
-    tt = render_task_template(dummy_template, "s3://becket")
+    template = get_task_template("dummy")
+    tt = render_task_template(template, "s3://becket")
     assert tt.container.args == [
         "pyflyte-fast-execute",
         "--additional-distribution",
@@ -286,7 +320,7 @@ def test_render_task_template():
         "--inputs",
         "s3://becket/inputs.pb",
         "--output-prefix",
-        "s3://becket/output",
+        "s3://becket",
         "--raw-output-data-prefix",
         "s3://becket/raw_output",
         "--checkpoint-path",
diff --git a/tests/flytekit/unit/sensor/test_file_sensor.py b/tests/flytekit/unit/sensor/test_file_sensor.py
index f6a50836be..bb0553dc27 100644
--- a/tests/flytekit/unit/sensor/test_file_sensor.py
+++ b/tests/flytekit/unit/sensor/test_file_sensor.py
@@ -16,7 +16,12 @@ def test_sensor_task():
         env={"FOO": "baz"},
         image_config=ImageConfig(default_image=default_img, images=[default_img]),
     )
-    assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"}
+    assert sensor.get_custom(settings) == {
+        "sensor_module": "flytekit.sensor.file_sensor",
+        "sensor_name": "FileSensor",
+        "sensor_config": None,
+        "inputs": None,
+    }
     tmp_file = tempfile.NamedTemporaryFile()

     @task()
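
# The custom payload above gains sensor_config and inputs keys because the
# sensor's metadata is now a dataclass. An inferred sketch (the real
# SensorMetadata lives in flytekit/sensor/base_sensor.py; the field types are
# assumptions read off its usage in these tests):
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class SensorMetadata:
    sensor_module: str
    sensor_name: str
    sensor_config: Optional[Dict[str, Any]] = None
    inputs: Optional[Dict[str, Any]] = None


assert asdict(SensorMetadata(sensor_module="flytekit.sensor.file_sensor", sensor_name="FileSensor")) == {
    "sensor_module": "flytekit.sensor.file_sensor",
    "sensor_name": "FileSensor",
    "sensor_config": None,
    "inputs": None,
}
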
diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py
index b5353b61b4..4a12aed877 100644
--- a/tests/flytekit/unit/sensor/test_sensor_engine.py
+++ b/tests/flytekit/unit/sensor/test_sensor_engine.py
@@ -1,20 +1,20 @@
 import tempfile
+from dataclasses import asdict

-import cloudpickle
 import pytest
-from flyteidl.admin.agent_pb2 import DeleteTaskResponse
 from flyteidl.core.execution_pb2 import TaskExecution

 import flytekit.models.interface as interface_models
 from flytekit.extend.backend.base_agent import AgentRegistry
 from flytekit.models import literals, types
 from flytekit.sensor import FileSensor
-from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME
+from flytekit.sensor.base_sensor import SensorMetadata
 from tests.flytekit.unit.extend.test_agent import get_task_template


 @pytest.mark.asyncio
 async def test_sensor_engine():
+    file = tempfile.NamedTemporaryFile()
     interfaces = interface_models.TypedInterface(
         {
             "path": interface_models.Variable(types.LiteralType(types.SimpleType.STRING), "description1"),
@@ -22,12 +22,10 @@ async def test_sensor_engine():
         {},
     )
     tmp = get_task_template("sensor")
-    tmp._custom = {
-        SENSOR_MODULE: FileSensor.__module__,
-        SENSOR_NAME: FileSensor.__name__,
-    }
-    file = tempfile.NamedTemporaryFile()
-
+    sensor_metadata = SensorMetadata(
+        sensor_module=FileSensor.__module__, sensor_name=FileSensor.__name__, inputs={"path": file.name}
+    )
+    tmp._custom = asdict(sensor_metadata)
     tmp._interface = interfaces

     task_inputs = literals.LiteralMap(
@@ -37,11 +35,10 @@ async def test_sensor_engine():
     )
     agent = AgentRegistry.get_agent("sensor")

-    res = await agent.create("/tmp", tmp, task_inputs)
+    res = await agent.create(tmp, task_inputs)

-    metadata_bytes = cloudpickle.dumps(tmp.custom)
-    assert res.resource_meta == metadata_bytes
-    res = await agent.get(metadata_bytes)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    res = await agent.delete(metadata_bytes)
-    assert res == DeleteTaskResponse()
+    assert res == sensor_metadata
+    resource = await agent.get(sensor_metadata)
+    assert resource.phase == TaskExecution.SUCCEEDED
+    res = await agent.delete(sensor_metadata)
+    assert res is None