diff --git a/frontend/src/context/workspaces/repositories.tsx b/frontend/src/context/workspaces/repositories.tsx index e5b5f3c1..a1b5dec3 100644 --- a/frontend/src/context/workspaces/repositories.tsx +++ b/frontend/src/context/workspaces/repositories.tsx @@ -16,12 +16,13 @@ import { toast } from "react-toastify"; import { createCustomContext } from "utils"; import { useWorkspaces } from "./workspaces"; export interface IPiecesContext { repositories: Repository[]; defaultRepositories: Repository[]; repositoryPieces: PiecesRepository; repositoriesLoading: boolean; + defaultRepositoriesIsLoading: boolean; selectedRepositoryId?: number; setSelectedRepositoryId: React.Dispatch< @@ -58,7 +59,8 @@ const PiecesProvider: React.FC<{ children: React.ReactNode }> = ({ const queryClient = useQueryClient(); - const { data: defaultRepositories } = useRepositories({ + const { data: defaultRepositories, isLoading: defaultRepositoriesIsLoading } = useRepositories({ + workspaceId: workspace?.id, source: "default", }); @@ -104,11 +106,11 @@ const PiecesProvider: React.FC<{ children: React.ReactNode }> = ({ const repositoryPiecesAux: PiecesRepository = {}; const foragePieces: PieceForageSchema = {}; - if (!pieces?.length) { + if (!pieces?.length && !_defaultPieces?.length) { localStorage.setItem("pieces", foragePieces); return repositoryPiecesAux; } else { - for (const piece of pieces) { + for (const piece of [...pieces, ..._defaultPieces]) { if (repositoryPiecesAux[piece.repository_id]?.length) { repositoryPiecesAux[piece.repository_id].push(piece); } else { @@ -121,7 +123,7 @@ const PiecesProvider: React.FC<{ children: React.ReactNode }> = ({ localStorage.setItem("pieces", foragePieces); return repositoryPiecesAux; } - }, [pieces]); + }, [pieces, _defaultPieces]); const fetchForagePieceById = useCallback((id: number) => { const pieces = localStorage.getItem("pieces"); @@ -135,6 +137,7 @@ const PiecesProvider: React.FC<{ children:
React.ReactNode }> = ({ defaultRepositories: defaultRepositories?.data ?? [], repositoryPieces, repositoriesLoading, + defaultRepositoriesIsLoading, selectedRepositoryId, setSelectedRepositoryId, diff --git a/frontend/src/features/workflowEditor/components/Drawers/PiecesDrawer/sidebarAddNode.tsx b/frontend/src/features/workflowEditor/components/Drawers/PiecesDrawer/sidebarAddNode.tsx index 23d6046e..bb136d9e 100644 --- a/frontend/src/features/workflowEditor/components/Drawers/PiecesDrawer/sidebarAddNode.tsx +++ b/frontend/src/features/workflowEditor/components/Drawers/PiecesDrawer/sidebarAddNode.tsx @@ -23,7 +23,20 @@ interface Props { } const SidebarAddNode: FC = ({ setOrientation, orientation }) => { - const { repositories, repositoriesLoading, repositoryPieces } = usesPieces(); + const { + repositories, + repositoriesLoading, + repositoryPieces, + defaultRepositories, + defaultRepositoriesIsLoading + } = usesPieces(); + + const controlRepository = useMemo(() => { + return defaultRepositories.find((repository) => { + return repository.name.includes("control"); + }); + }, [defaultRepositories]); + const [filter, setFilter] = useState(""); const [expanded, setExpanded] = useState>({}); @@ -71,12 +84,18 @@ const SidebarAddNode: FC = ({ setOrientation, orientation }) => { } }, [filter]); + const isLoading = useMemo(() =>{ + return repositoriesLoading || defaultRepositoriesIsLoading; + }, [repositoriesLoading, defaultRepositoriesIsLoading]) + + const allRepos = controlRepository ? [...repositories, controlRepository] : repositories; + return ( - {repositoriesLoading && ( + {isLoading && ( Loading repositories... 
)} - {!repositoriesLoading && ( + {!isLoading && ( = ({ setOrientation, orientation }) => { variant="filled" label="search" /> - {!repositoriesLoading && - repositories.map((repo) => { + {!isLoading && + allRepos.map((repo) => { if (!filteredRepositoryPieces[repo.id]?.length) { return null; } diff --git a/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/BatchNode/index.tsx b/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/BatchNode/index.tsx new file mode 100644 index 00000000..c95a13f1 --- /dev/null +++ b/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/BatchNode/index.tsx @@ -0,0 +1,111 @@ +import { Paper, Typography, useTheme } from "@mui/material"; +import { type CSSProperties, memo, useMemo } from "react"; +import { Handle, Position } from "reactflow"; +import { Icon } from "@iconify/react"; + +import { type DefaultNodeProps } from "../types"; + +export const BatchNode = memo(({ id, data, selected }) => { + const theme = useTheme(); + + const handleStyle = useMemo( + () => ({ + border: 0, + borderRadius: "16px", + backgroundColor: theme.palette.info.main, + transition: "ease 100", + zIndex: 2, + width: "12px", + height: "12px", + }), + [theme.palette.info.main] + ); + + const nodeStyle = useMemo( + () => ({ + display: "flex", + flexDirection: "column", + alignItems: "center", + justifyContent: "center", + position: "relative", + padding: 1, + textAlign: "center", + width: "100px", // Adjusted width to make it a square + height: "100px", // Adjusted height to make it a square + lineHeight: "60px", + border: selected ? 
`2px solid ${theme.palette.info.dark}` : "2px solid transparent", // Border color change based on selection + color: theme.palette.getContrastText(theme.palette.background.paper), + backgroundColor: theme.palette.background.paper, + borderRadius: "3px", + }), + [selected, theme.palette] + ); + + const iconStyle = useMemo( + () => ({ + width: "50px", // Adjusted width of the icon + height: "50px", // Adjusted height of the icon + marginBottom: "10px", // Added margin bottom to space out from text + }), + [] + ); + + const labelStyle = { + fontSize: "8px", + position: "absolute", + textAlign: "left", + justifyContent: "left", + right: "-25px", + top: "10px" + }; + + return ( + <> + connection.sourceHandle !== `target-${id}`} + /> + + DONE + + + BATCH + + + + Batch Piece + + + + + ); +}); + +BatchNode.displayName = "BatchNode"; + +export default BatchNode; diff --git a/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/WorkflowPanel.tsx b/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/WorkflowPanel.tsx index af41fb11..69c9f051 100644 --- a/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/WorkflowPanel.tsx +++ b/frontend/src/features/workflowEditor/components/Panel/WorkflowPanel/WorkflowPanel.tsx @@ -47,6 +47,7 @@ import { import { CustomConnectionLine } from "./ConnectionLine"; import DefaultEdge from "./DefaultEdge"; import { CustomNode } from "./DefaultNode"; +import BatchNode from "./BatchNode"; // Import the BatchNode import { type DefaultNode } from "./types"; const getId = (module_name: string) => { @@ -55,6 +56,7 @@ const getId = (module_name: string) => { const DEFAULT_NODE_TYPES: NodeTypes = { CustomNode, + BatchNode, }; const EDGE_TYPES: EdgeTypes = { @@ -106,7 +108,8 @@ const WorkflowPanel = forwardRef( setInstance(instance); const edges = getWorkflowEdges(); const nodes = getWorkflowNodes(); - setNodes(nodes); + + setNodes([...nodes, defaultBatchNode]); setEdges(edges); 
window.requestAnimationFrame(() => instance.fitView()); }, @@ -126,9 +129,14 @@ const WorkflowPanel = forwardRef( orientation: data?.orientation ?? "horizontal", }; + let nodeType = "CustomNode"; + if (data.repository_url === 'domino-default/default_control_repository') { + nodeType = "BatchNode"; + } + const newNode = { id: getId(data.id), - type: "CustomNode", + type: nodeType, position, data: newNodeData, }; diff --git a/rest/auth/permission_authorizer.py b/rest/auth/permission_authorizer.py index 03f78ab4..194d5aa1 100644 --- a/rest/auth/permission_authorizer.py +++ b/rest/auth/permission_authorizer.py @@ -1,8 +1,7 @@ from fastapi import HTTPException, Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -import jwt from schemas.errors.base import ForbiddenError, ResourceNotFoundError -from schemas.context.auth_context import AuthorizationContextData, WorkspaceAuthorizerData +from schemas.context.auth_context import WorkspaceAuthorizerData from database.models.enums import Permission, UserWorkspaceStatus from typing import Optional, Dict from auth.base_authorizer import BaseAuthorizer diff --git a/rest/constants/default_pieces/control/__init__.py b/rest/constants/default_pieces/control/__init__.py new file mode 100644 index 00000000..b2fa29e3 --- /dev/null +++ b/rest/constants/default_pieces/control/__init__.py @@ -0,0 +1,5 @@ +from .batch import BatchPiece + +DEFAULT_CONTROL_PIECES = [ + BatchPiece +] \ No newline at end of file diff --git a/rest/constants/default_pieces/control/batch.py b/rest/constants/default_pieces/control/batch.py new file mode 100644 index 00000000..9b0c6a84 --- /dev/null +++ b/rest/constants/default_pieces/control/batch.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field, PositiveInt +from typing import List, Union, Optional +from datetime import datetime + + + +class InputModel(BaseModel): + batch_over: List[Union[str, int, float, bool, dict, datetime]] = Field( + title='Batch Over', + description='List 
to iterate over', + json_schema_extra={"from_upstream": "always"} + ) + max_concurrency: PositiveInt = Field( + title='Max Concurrency', + description='Max number of parallel executions' + ) + +class OutputMode(BaseModel): + output: List[Union[str, int, float, bool, dict, datetime]] = Field( + title='Output', + description='Output of the batch', + ) # TODO use output modifier? + +batch_piece_default_style = { + "label": "Batch Piece", + "iconClassName": "ic:baseline-loop", + "nodeType": "control", +} +class BatchPiece(BaseModel): + name: str = Field(title='Name', default='BatchPiece') + description: str = Field(title='Description', default='Piece to run batch processing with concurrency.') + input_schema: dict = Field(default=InputModel.model_json_schema()) + secrets_schema: Optional[dict] = Field(default=None) + style: dict = Field(default=batch_piece_default_style) + diff --git a/rest/core/settings.py b/rest/core/settings.py index d051a48d..4f3a71a1 100644 --- a/rest/core/settings.py +++ b/rest/core/settings.py @@ -93,6 +93,14 @@ class Settings(BaseSettings): version="0.0.1", url="domino-default/default_storage_repository" ) + DEFAULT_CONTROL_REPOSITORY: dict = dict( + name="default_control_repository", + path="default_control_repository", + source=getattr(RepositorySource, 'default').value, + version="0.0.1", + label="Default Control Repository", + url="domino-default/default_control_repository" + ) DEPLOY_MODE: str = os.environ.get('DOMINO_DEPLOY_MODE', 'local-k8s') diff --git a/rest/database/models/piece_repository.py b/rest/database/models/piece_repository.py index d9b97e92..2ec04bba 100644 --- a/rest/database/models/piece_repository.py +++ b/rest/database/models/piece_repository.py @@ -2,14 +2,14 @@ from sqlalchemy import Column, String, Integer, DateTime, Enum, JSON, ForeignKey from database.models.enums import RepositorySource from sqlalchemy.orm import relationship -from datetime import datetime +from datetime import datetime, timezone class 
PieceRepository(Base, BaseDatabaseModel): __tablename__ = "piece_repository" id = Column(Integer, primary_key=True) - created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.utcnow) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(tz=timezone.utc)) name = Column(String(50), unique=False) label = Column(String(50), unique=False) source = Column(Enum(RepositorySource), nullable=True, default=RepositorySource.github.value) diff --git a/rest/services/piece_repository_service.py b/rest/services/piece_repository_service.py index 8dd5f5c6..a1266054 100644 --- a/rest/services/piece_repository_service.py +++ b/rest/services/piece_repository_service.py @@ -2,7 +2,7 @@ import json import tomli from math import ceil -from datetime import datetime +from datetime import datetime, timezone from core.logger import get_configured_logger from schemas.context.auth_context import AuthorizationContextData from schemas.requests.piece_repository import CreateRepositoryRequest, PatchRepositoryRequest, ListRepositoryFilters @@ -29,6 +29,7 @@ from core.settings import settings + class PieceRepositoryService(object): def __init__(self) -> None: self.logger = get_configured_logger(self.__class__.__name__) @@ -132,7 +133,7 @@ def patch_piece_repository( version=piece_repository_data.version ) new_repo = PieceRepository( - created_at=datetime.utcnow(), + created_at=datetime.now(tz=timezone.utc), name=repository_files_metadata['config_toml'].get('repository').get('REPOSITORY_NAME'), source=repository.source, path=repository.path, @@ -176,6 +177,25 @@ def patch_piece_repository( return PatchRepositoryResponse(**repository.to_dict()) + def create_default_control_repository(self, workspace_id: int): + self.logger.info("Creating default control repository") + new_repo = PieceRepository( + name=settings.DEFAULT_CONTROL_REPOSITORY['name'], + created_at=datetime.now(tz=timezone.utc), + workspace_id=workspace_id, +
label=settings.DEFAULT_CONTROL_REPOSITORY['label'], + path=settings.DEFAULT_CONTROL_REPOSITORY['path'], + source=settings.DEFAULT_CONTROL_REPOSITORY['source'], + version=settings.DEFAULT_CONTROL_REPOSITORY['version'], + url=settings.DEFAULT_CONTROL_REPOSITORY['url'] + ) + default_control_repository = self.piece_repository_repository.create(piece_repository=new_repo) + + self.piece_service.create_default_control_pieces( + piece_repository_id=default_control_repository.id, + ) + return default_control_repository + def create_default_storage_repository(self, workspace_id: int): """ Create default storage repository for workspace. @@ -185,7 +205,7 @@ def create_default_storage_repository(self, workspace_id: int): new_repo = PieceRepository( name=settings.DEFAULT_STORAGE_REPOSITORY['name'], - created_at=datetime.utcnow(), + created_at=datetime.now(tz=timezone.utc), workspace_id=workspace_id, path=settings.DEFAULT_STORAGE_REPOSITORY['path'], source=settings.DEFAULT_STORAGE_REPOSITORY['source'], @@ -228,7 +248,7 @@ def create_piece_repository( github_access_token=token ) new_repo = PieceRepository( - created_at=datetime.utcnow(), + created_at=datetime.now(tz=timezone.utc), name=repository_files_metadata['config_toml'].get('repository').get('REPOSITORY_NAME'), source=piece_repository_data.source, path=piece_repository_data.path, diff --git a/rest/services/piece_service.py b/rest/services/piece_service.py index a0e51cc2..c5282e0b 100644 --- a/rest/services/piece_service.py +++ b/rest/services/piece_service.py @@ -12,6 +12,7 @@ from schemas.responses.piece import GetPiecesResponse from utils.base_node_style import get_frontend_node_style from constants.default_pieces.storage import DEFAULT_STORAGE_PIECES +from constants.default_pieces.control import DEFAULT_CONTROL_PIECES class PieceService(object): @@ -127,6 +128,25 @@ def _update_pieces_from_metadata(self, piece_metadata: dict, dependencies_map: d return self.piece_repository.update(new_piece, piece_id=db_piece.id) + def 
create_default_control_pieces(self, piece_repository_id: int = 1) -> None: + """Create default control pieces in database + """ + self.logger.info("Creating default control pieces") + + pieces = [] + for piece_model in DEFAULT_CONTROL_PIECES: + model = piece_model() + piece = Piece( + name=model.name, + description=model.description, + secrets_schema=model.secrets_schema, + input_schema=model.input_schema, + repository_id=piece_repository_id, + style=model.style + ) + pieces.append(piece) + pieces = self.piece_repository.create_many(pieces) + return pieces def create_default_storage_pieces(self, piece_repository_id: int = 1) -> None: """Create default storage pieces in database diff --git a/rest/services/workflow_service.py b/rest/services/workflow_service.py index 99fafbf6..269227cf 100644 --- a/rest/services/workflow_service.py +++ b/rest/services/workflow_service.py @@ -78,11 +78,11 @@ def create_workflow( new_workflow = Workflow( name=body.workflow.name, uuid_name=workflow_id, - created_at=datetime.now(), + created_at=datetime.now(tz=timezone.utc), schema=body.forageSchema, ui_schema=body.ui_schema.model_dump(), created_by=auth_context.user_id, - last_changed_at=datetime.now(), + last_changed_at=datetime.now(tz=timezone.utc), start_date=body.workflow.start_date, end_date=body.workflow.end_date, schedule=body.workflow.schedule, diff --git a/rest/services/workspace_service.py b/rest/services/workspace_service.py index 4b9a5eed..2a3a35d0 100644 --- a/rest/services/workspace_service.py +++ b/rest/services/workspace_service.py @@ -68,6 +68,9 @@ def create_workspace( self.piece_repository_service.create_default_storage_repository( workspace_id=workspace.id ) + self.piece_repository_service.create_default_control_repository( + workspace_id=workspace.id + ) auth_context.workspace = WorkspaceAuthorizerData( id=workspace.id, name=workspace.name, diff --git a/src/domino/custom_operators/base_operator.py b/src/domino/custom_operators/base_operator.py new file mode 100644 
index 00000000..50cabc98 --- /dev/null +++ b/src/domino/custom_operators/base_operator.py @@ -0,0 +1,91 @@ +from domino.custom_operators.docker_operator import DominoDockerOperator +from domino.custom_operators.k8s_operator import DominoKubernetesPodOperator +import os +from typing import Optional, Dict +from domino.schemas.shared_storage import WorkflowSharedStorage + + +class DominoBaseOperator: + def __new__( + cls, + dag_id: str, + task_id: str, + piece_name: str, + piece_source_image: str, + deploy_mode: str, + repository_url: str, + repository_version: str, + workspace_id: int, + piece_input_kwargs: Optional[Dict] = None, + workflow_shared_storage: Optional[WorkflowSharedStorage] = None, + container_resources: Optional[Dict] = None, + ): + if deploy_mode == "local-compose": + cls.operator_kwargs = dict( + entrypoint=["domino", "run-piece-docker"], + do_xcom_push=True, + mount_tmp_dir=False, + tty=True, + xcom_all=False, + retrieve_output=True, + retrieve_output_path="/airflow/xcom/return.out" + ) + return DominoDockerOperator( + dag_id=dag_id, + task_id=task_id, + piece_name=piece_name, + deploy_mode=deploy_mode, + repository_url=repository_url, + repository_version=repository_version, + workspace_id=workspace_id, + piece_input_kwargs=piece_input_kwargs, + workflow_shared_storage=workflow_shared_storage, + container_resources=container_resources or {}, + image=piece_source_image, + # ----------------- Docker ----------------- + # TODO uncoment + **cls.operator_kwargs + ) + elif deploy_mode in ["local-k8s", "local-k8s-dev", "prod", "k8s"]: + cls.operator_kwargs = dict( + namespace="default", + image_pull_policy="IfNotPresent", + name=f"airflow-worker-pod-{task_id}", + startup_timeout_seconds=600, + annotations={ + "sidecar.istio.io/inject": "false" + }, # TODO - remove this when istio is working with airflow k8s pod + # cmds=["/bin/bash"], + # arguments=["-c", "sleep 120;"], + cmds=["domino"], + arguments=["run-piece-k8s"], + do_xcom_push=True, + 
in_cluster=True + ) + return DominoKubernetesPodOperator( + dag_id=dag_id, + task_id=task_id, + piece_name=piece_name, + deploy_mode=deploy_mode, + repository_url=repository_url, + repository_version=repository_version, + workspace_id=workspace_id, + piece_input_kwargs=piece_input_kwargs, + workflow_shared_storage=workflow_shared_storage, + container_resources=container_resources or {}, + image=piece_source_image, + # ----------------- Kubernetes ----------------- + **cls.operator_kwargs + ) + else: + raise Exception(f"Invalid deploy mode: {deploy_mode}") + + @classmethod + def partial(cls, **kwargs): + deploy_mode = os.environ.get("DOMINO_DEPLOY_MODE") + if deploy_mode == "local-compose": + return DominoDockerOperator.partial(**kwargs, **cls.operator_kwargs) + elif deploy_mode in ["local-k8s", "local-k8s-dev", "prod", "k8s"]: + return DominoKubernetesPodOperator.partial(**kwargs, **cls.operator_kwargs) + else: + raise Exception(f"Invalid deploy mode: {deploy_mode}") \ No newline at end of file diff --git a/src/domino/custom_operators/docker_operator.py b/src/domino/custom_operators/docker_operator.py index 9ed0f75f..232a3111 100644 --- a/src/domino/custom_operators/docker_operator.py +++ b/src/domino/custom_operators/docker_operator.py @@ -2,11 +2,11 @@ from airflow.utils.context import Context from typing import Dict, Optional, Any import os - from domino.client.domino_backend_client import DominoBackendRestClient from domino.schemas import WorkflowSharedStorage, StorageSource -from docker.types import Mount import docker + + class DominoDockerOperator(DockerOperator): def __init__( @@ -21,6 +21,7 @@ def __init__( piece_input_kwargs: Optional[Dict] = None, workflow_shared_storage: WorkflowSharedStorage = None, container_resources: Optional[Dict] = None, + test_test: Optional[Dict] = None, **docker_operator_kwargs ) -> None: self.task_id = task_id @@ -55,7 +56,7 @@ def __init__( mounts = [] # TODO remove - used in DEV only ####################### - dev_pieces = 
False + dev_pieces = False if dev_pieces: piece_repo_name = repository_url.split("/")[-1] #local_repos_path = f"/mnt/shared_storage/Github/{piece_repo_name}" @@ -130,12 +131,29 @@ def _get_piece_kwargs_value_from_upstream_xcom( self, value: Any ): - if isinstance(value, dict) and value.get("type") == "fromUpstream": + if ( + isinstance(value, dict) + and value.get("type") == "fromUpstream" + and "batch_task_group" not in value.get("upstream_task_id") + ): upstream_task_id = value["upstream_task_id"] - output_arg = value["output_arg"] + upstream_output_arg = value["output_arg"] # upstream output arg if upstream_task_id not in self.shared_storage_upstream_ids_list: self.shared_storage_upstream_ids_list.append(upstream_task_id) - return self.upstream_xcoms_data[upstream_task_id][output_arg] + return self.upstream_xcoms_data[upstream_task_id][upstream_output_arg] + + if ( + isinstance(value, dict) + and value.get("type") == "fromUpstream" + and "batch_task_group" in value.get("upstream_task_id") + ): + upstream_task_id: str = value["upstream_task_id"] + upstream_output_arg: str = value["output_arg"] # upstream output arg + dynamic_mapped_upstream_xcom_data: list = self.upstream_xcoms_data[upstream_task_id] + value = [] + for e in dynamic_mapped_upstream_xcom_data: + value.append(e[upstream_output_arg]) + return value elif isinstance(value, list): return [self._get_piece_kwargs_value_from_upstream_xcom(item) for item in value] elif isinstance(value, dict): diff --git a/src/domino/task.py b/src/domino/task.py index 1d31d0f1..07196608 100644 --- a/src/domino/task.py +++ b/src/domino/task.py @@ -1,13 +1,7 @@ from airflow import DAG -from airflow.models import BaseOperator -from datetime import datetime from typing import Callable import os - -from domino.custom_operators.k8s_operator import DominoKubernetesPodOperator -from domino.custom_operators.docker_operator import DominoDockerOperator -from domino.custom_operators.python_operator import PythonOperator -from
domino.custom_operators.worker_operator import DominoWorkerOperator +from domino.custom_operators.base_operator import DominoBaseOperator from domino.logger import get_configured_logger from domino.schemas import shared_storage_map, StorageSource @@ -26,7 +20,8 @@ def __init__( piece_input_kwargs: dict, workflow_shared_storage: dict = None, container_resources: dict = None, - **kwargs + init_operator: bool = True, + #**kwargs ) -> None: # Task configuration and attributes self.task_id = task_id @@ -39,10 +34,6 @@ def __init__( self.repository_version = piece["repository_version"] self.piece = piece self.piece_input_kwargs = piece_input_kwargs - if "execution_mode" not in self.piece: - self.execution_mode = "docker" - else: - self.execution_mode = self.piece["execution_mode"] # Shared storage if not workflow_shared_storage: @@ -57,84 +48,17 @@ def __init__( else: self.workflow_shared_storage = shared_storage_map[shared_storage_source_name] - # Container resources self.container_resources = container_resources - - # Get deploy mode self.deploy_mode = os.environ.get('DOMINO_DEPLOY_MODE') # Set up task operator - self._task_operator = self._set_operator() - - def _set_operator(self) -> BaseOperator: - """ - Set Airflow Operator according to deploy mode and Piece execution mode. 
- """ - if self.execution_mode == "worker": - return DominoWorkerOperator( - dag_id=self.dag_id, - task_id=self.task_id, - piece_name=self.piece.get('name'), - repository_name=self.piece.get('repository_name'), - workflow_id=self.piece.get('workflow_id'), - piece_input_kwargs=self.piece_input_kwargs, - ) - - if self.deploy_mode == "local-python": - return PythonOperator( - dag=self.dag, - task_id=self.task_id, - start_date=datetime(2021, 1, 1), # TODO - get correct start_date - provide_context=True, - op_kwargs=self.piece_input_kwargs, - # queue=dependencies_group, - make_python_callable_kwargs=dict( - piece_name=self.piece_name, - deploy_mode=self.deploy_mode, - task_id=self.task_id, - dag_id=self.dag_id, - ) - ) - - elif self.deploy_mode in ["local-k8s", "local-k8s-dev", "prod", "k8s"]: - # References: - # - https://airflow.apache.org/docs/apache-airflow/1.10.14/_api/airflow/contrib/operators/kubernetes_pod_operator/index.html - # - https://airflow.apache.org/docs/apache-airflow/stable/templates-ref.html - # - https://www.astronomer.io/guides/templating/ - # - good example: https://github.com/apache/airflow/blob/main/tests/system/providers/cncf/kubernetes/example_kubernetes.py - # - commands HAVE to go in a list object: https://stackoverflow.com/a/55149915/11483674 - - return DominoKubernetesPodOperator( - dag_id=self.dag_id, - task_id=self.task_id, - piece_name=self.piece.get('name'), - deploy_mode=self.deploy_mode, - repository_url=self.repository_url, - repository_version=self.repository_version, - workspace_id=self.workspace_id, - piece_input_kwargs=self.piece_input_kwargs, - workflow_shared_storage=self.workflow_shared_storage, - container_resources=self.container_resources, - # ----------------- Kubernetes ----------------- - namespace='default', - image=self.piece.get("source_image"), - image_pull_policy='IfNotPresent', - name=f"airflow-worker-pod-{self.task_id}", - startup_timeout_seconds=600, - annotations={"sidecar.istio.io/inject": "false"}, # TODO - 
remove this when istio is working with airflow k8s pod - # cmds=["/bin/bash"], - # arguments=["-c", "sleep 120;"], - cmds=["domino"], - arguments=["run-piece-k8s"], - do_xcom_push=True, - in_cluster=True, - ) - - elif self.deploy_mode == 'local-compose': - return DominoDockerOperator( + if not init_operator: + self._task_operator = None + else: + self._task_operator = DominoBaseOperator( dag_id=self.dag_id, task_id=self.task_id, - piece_name=self.piece.get('name'), + piece_name=self.piece.get("name"), deploy_mode=self.deploy_mode, repository_url=self.repository_url, repository_version=self.repository_version, @@ -142,16 +66,7 @@ def _set_operator(self) -> BaseOperator: piece_input_kwargs=self.piece_input_kwargs, workflow_shared_storage=self.workflow_shared_storage, container_resources=self.container_resources, - # ----------------- Docker ----------------- - # TODO uncoment - image=self.piece["source_image"], - entrypoint=["domino", "run-piece-docker"], - do_xcom_push=True, - mount_tmp_dir=False, - tty=True, - xcom_all=False, - retrieve_output=True, - retrieve_output_path='/airflow/xcom/return.out', + piece_source_image=self.piece.get("source_image"), ) def __call__(self) -> Callable: