Skip to content

Commit

Permalink
re-org & init Interceptor
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Apr 17, 2024
1 parent e06f35e commit d704fcb
Show file tree
Hide file tree
Showing 13 changed files with 46 additions and 18 deletions.
8 changes: 8 additions & 0 deletions flyrs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions flyrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tokio = { version = "1.37.0", features = ["full"] }
pyo3 = { version = "0.21", features = ["extension-module", "experimental-async"] }

flyteidl = { path="../../flyte/flyteidl" }
tower = "0.4.13"

[build-dependencies]

2 changes: 1 addition & 1 deletion flyrs/remote/backfill.py → flyrs/clients/backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from flytekit import LaunchPlan
from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase, WorkflowFailurePolicy
from remote.entities import FlyteLaunchPlan
from entities import FlyteLaunchPlan


def create_backfill_workflow(
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions flyrs/remote/entities.py → flyrs/clients/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from flytekit.models.interface import TypedInterface
from flytekit.models.literals import Binding
from flytekit.models.task import TaskSpec
import remote.interface as _interfaces
from remote.remote_callable import RemoteEntity
import interface as _interfaces
from remote_callable import RemoteEntity


class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, TaskSpec):
Expand Down
2 changes: 1 addition & 1 deletion flyrs/remote/executions.py → flyrs/clients/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flytekit.models import node_execution as node_execution_models
from flytekit.models.admin import task_execution as admin_task_execution_models
from flytekit.models.core import execution as core_execution_models
from remote.entities import FlyteTask, FlyteWorkflow
from entities import FlyteTask, FlyteWorkflow


class RemoteExecutionBase(object):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from threading import Lock

from flytekit import FlyteContext
from remote.remote_callable import RemoteEntity
from remote_callable import RemoteEntity

T = typing.TypeVar("T", bound=RemoteEntity)

Expand Down
File renamed without changes.
File renamed without changes.
16 changes: 8 additions & 8 deletions flyrs/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@
)
from flytekit.models.launch_plan import LaunchPlanState
from flytekit.models.literals import Literal, LiteralMap
from remote.backfill import create_backfill_workflow
from remote.data import download_literal
from remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from remote.interface import TypedInterface
from remote.lazy_entity import LazyEntity
from remote.remote_callable import RemoteEntity
from remote.remote_fs import get_flyte_fs
from clients.backfill import create_backfill_workflow
from clients.data import download_literal
from clients.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from clients.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from clients.interface import TypedInterface
from clients.lazy_entity import LazyEntity
from clients.remote_callable import RemoteEntity
from clients.remote_fs import get_flyte_fs
from flytekit.tools.fast_registration import fast_package
from flytekit.tools.interactive import ipython_check
from flytekit.tools.script_mode import _find_project_root, compress_scripts, hash_file
Expand Down
29 changes: 24 additions & 5 deletions flyrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@ use prost::{Message};
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use tokio::runtime::{Builder, Runtime};
use tonic::transport::Channel;
use tonic::{
metadata::MetadataValue,
codegen::InterceptedService,
service::Interceptor,
transport::{Channel, Endpoint, Error},
Request, Status,
};

use flyteidl::flyteidl::service::admin_service_client::AdminServiceClient;
use flyteidl::flyteidl::admin;//::{Task, ObjectGetRequest, ResourceListRequest, NamedEntityIdentifierListRequest, TaskExecutionGetRequest};
use flyteidl::flyteidl::admin;
use std::option::Option;

// Unlike the normal use case of PyO3, we don't have to add attribute macros such as #[pyclass] or #[pymethods] to all of our flyteidl structs.
// In this case, we only use PyO3 to expose the client class and its methods to Python (FlyteKit).
Expand All @@ -16,10 +23,20 @@ use flyteidl::flyteidl::admin;//::{Task, ObjectGetRequest, ResourceListRequest,

#[pyclass(subclass)]
pub struct FlyteClient {
admin_service: AdminServiceClient<Channel>,
admin_service: AdminServiceClient<InterceptedService<Channel, AuthUnaryInterceptor>>,
runtime: Runtime,
}

struct AuthUnaryInterceptor;

impl Interceptor for AuthUnaryInterceptor {
fn call(&mut self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
let token: MetadataValue<_> = "Bearer some-auth-token".parse().unwrap();
request.metadata_mut().insert("authorization", token.clone());
Ok(request)
}
}

// Using temporary value(e.g., endpoint) in async is tricky w.r.t lifetime.
// The compiler will complain that the temporary value does not live long enough.
// TODO: figure out how to pass in the required initial args into constructor in a clean and neat way.
Expand All @@ -29,10 +46,12 @@ impl FlyteClient {
pub fn new() -> PyResult<FlyteClient> {
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
// TODO: Create a channel then bind it to every stubs/clients instead of connecting everytime.
let stub = rt.block_on(AdminServiceClient::connect("http://localhost:30080")).unwrap();
let channel = rt.block_on(Endpoint::from_static("http://localhost:30080").connect()).unwrap();
// let stub = rt.block_on(AdminServiceClient::connect("http://localhost:30080")).unwrap();
let mut stub = AdminServiceClient::with_interceptor(channel, AuthUnaryInterceptor);
// TODO: Add more thoughtful error handling
Ok(FlyteClient {
runtime: rt, // The tokio runtime is used in a blocking manner now, left lots of investigation and TODOs behind.
runtime: rt, // The tokio runtime is used in a blocking manner now, leaving lots of investigation and TODOs behind.
admin_service: stub,
}
)
Expand Down

0 comments on commit d704fcb

Please sign in to comment.