From 3df80da471334da1cf5c99bec779ee649922141b Mon Sep 17 00:00:00 2001 From: Jeremy Shih Date: Wed, 4 Sep 2024 11:58:51 +0800 Subject: [PATCH] chore(main): new direct service initialization (#199) related changes for documentation at https://github.com/instill-ai/instill.tech/pull/1087 Because - users should be able to initialize service directly This commit - add new direct service initialization functions --- instill/clients/artifact.py | 20 +++- instill/clients/client.py | 51 +++++++++ instill/clients/mgmt.py | 12 ++- instill/clients/model.py | 26 ++++- instill/clients/pipeline.py | 170 +++++++++++++++--------------- instill/configuration/__init__.py | 5 +- instill/tests/test_client.py | 7 +- tests/test_client.py | 7 +- 8 files changed, 195 insertions(+), 103 deletions(-) diff --git a/instill/clients/artifact.py b/instill/clients/artifact.py index 6fbad4dc..fee813d8 100644 --- a/instill/clients/artifact.py +++ b/instill/clients/artifact.py @@ -19,7 +19,7 @@ class ArtifactClient(Client): - def __init__(self, async_enabled: bool) -> None: + def __init__(self, async_enabled: bool = False, api_token: str = "") -> None: self.hosts: Dict[str, InstillInstance] = {} if DEFAULT_INSTANCE in global_config.hosts: self.instance = DEFAULT_INSTANCE @@ -30,14 +30,27 @@ def __init__(self, async_enabled: bool) -> None: if global_config.hosts is not None: for instance, config in global_config.hosts.items(): + token = config.token + if api_token != "" and instance == self.instance: + token = api_token self.hosts[instance] = InstillInstance( artifact_service.ArtifactPublicServiceStub, url=config.url, - token=config.token, + token=token, secure=config.secure, async_enabled=async_enabled, ) + def close(self): + if self.is_serving(): + for host in self.hosts.values(): + host.channel.close() + + async def async_close(self): + if self.is_serving(): + for host in self.hosts.values(): + await host.async_channel.close() + @property def hosts(self): return self._hosts @@ -54,6 +67,9 @@ def instance(self): def instance(self, instance: str): self._instance = instance + def set_instance(self, instance: str): + self._instance = instance + @property def metadata(self): return self._metadata diff --git a/instill/clients/client.py b/instill/clients/client.py index 9870663e..4bdebec1 100644 --- a/instill/clients/client.py +++ b/instill/clients/client.py @@ -75,3 +75,54 @@ async def async_close(self): def get_client(async_enabled: bool = False) -> InstillClient: return InstillClient(async_enabled=async_enabled) + + +def init_artifact_client( + api_token: str = "", async_enabled: bool = False +) -> ArtifactClient: + client = ArtifactClient(api_token=api_token, async_enabled=async_enabled) + if not client.is_serving(): + Logger.w( + "Instill Artifact is not serving, Artifact functionalities will not work" + ) + raise NotServingException + + return client + + +def init_model_client(api_token: str = "", async_enabled: bool = False) -> ModelClient: + mgmt_service = MgmtClient(api_token=api_token, async_enabled=async_enabled) + if not mgmt_service.is_serving(): + Logger.w("Instill Core is required") + raise NotServingException + + user_id = mgmt_service.get_user().user.id + + client = ModelClient( + namespace=user_id, api_token=api_token, async_enabled=async_enabled + ) + if not client.is_serving(): + Logger.w("Instill Model is not serving, Model functionalities will not work") + raise NotServingException + + return client + + +def init_pipeline_client( + api_token: str = "", async_enabled: bool = False +) -> PipelineClient: + mgmt_service = MgmtClient(api_token=api_token, async_enabled=async_enabled) + if not mgmt_service.is_serving(): + Logger.w("Instill Core is required") + raise NotServingException + + user_id = mgmt_service.get_user().user.id + + client = PipelineClient( + namespace=user_id, api_token=api_token, async_enabled=async_enabled + ) + if not client.is_serving(): + Logger.w("Instill VDP is not serving, VDP functionalities will not work") + raise NotServingException + + return client diff --git a/instill/clients/mgmt.py b/instill/clients/mgmt.py index c20b300a..1987a941 100644 --- a/instill/clients/mgmt.py +++ b/instill/clients/mgmt.py @@ -16,11 +16,9 @@ from instill.configuration import global_config from instill.utils.error_handler import grpc_handler -# from instill.utils.logger import Logger - class MgmtClient(Client): - def __init__(self, async_enabled: bool) -> None: + def __init__(self, async_enabled: bool = False, api_token: str = "") -> None: self.hosts: Dict[str, InstillInstance] = {} if DEFAULT_INSTANCE in global_config.hosts: self.instance = DEFAULT_INSTANCE @@ -31,10 +29,13 @@ def __init__(self, async_enabled: bool) -> None: if global_config.hosts is not None: for instance, config in global_config.hosts.items(): + token = config.token + if api_token != "" and instance == self.instance: + token = api_token self.hosts[instance] = InstillInstance( mgmt_service.MgmtPublicServiceStub, url=config.url, - token=config.token, + token=token, secure=config.secure, async_enabled=async_enabled, ) @@ -55,6 +56,9 @@ def instance(self): def instance(self, instance: str): self._instance = instance + def set_instance(self, instance: str): + self._instance = instance + @property def metadata(self): return self._metadata diff --git a/instill/clients/model.py b/instill/clients/model.py index 9d7b76d0..48d0fb4a 100644 --- a/instill/clients/model.py +++ b/instill/clients/model.py @@ -19,7 +19,9 @@ class ModelClient(Client): - def __init__(self, namespace: str, async_enabled: bool) -> None: + def __init__( + self, namespace: str, async_enabled: bool = False, api_token: str = "" + ) -> None: self.hosts: Dict[str, InstillInstance] = {} self.namespace: str = namespace if DEFAULT_INSTANCE in global_config.hosts: @@ -31,14 +33,27 @@ def __init__(self, namespace: str, async_enabled: bool) -> None: if global_config.hosts is not None: for instance, config in global_config.hosts.items(): + token = config.token + if api_token != "" and instance == self.instance: + token = api_token self.hosts[instance] = InstillInstance( model_service.ModelPublicServiceStub, url=config.url, - token=config.token, + token=token, secure=config.secure, async_enabled=async_enabled, ) + def close(self): + if self.is_serving(): + for host in self.hosts.values(): + host.channel.close() + + async def async_close(self): + if self.is_serving(): + for host in self.hosts.values(): + await host.async_channel.close() + @property def hosts(self): return self._hosts @@ -55,6 +70,9 @@ def instance(self): def instance(self, instance: str): self._instance = instance + def set_instance(self, instance: str): + self._instance = instance + @property def metadata(self): return self._metadata @@ -200,7 +218,7 @@ def trigger_model( return RequestFactory( method=self.hosts[self.instance].async_client.TriggerUserModel, request=model_interface.TriggerUserModelRequest( - name=f"{self.namespace}/models/{model_name}", + name=f"namespaces/{self.namespace}/models/{model_name}", task_inputs=task_inputs, version=version, ), @@ -210,7 +228,7 @@ def trigger_model( return RequestFactory( method=self.hosts[self.instance].client.TriggerUserModel, request=model_interface.TriggerUserModelRequest( - name=f"{self.namespace}/models/{model_name}", + name=f"namespaces/{self.namespace}/models/{model_name}", task_inputs=task_inputs, version=version, ), diff --git a/instill/clients/pipeline.py b/instill/clients/pipeline.py index b2cf4aca..cdc77bc5 100644 --- a/instill/clients/pipeline.py +++ b/instill/clients/pipeline.py @@ -23,7 +23,11 @@ class PipelineClient(Client): def __init__( - self, namespace: str, async_enabled: bool, target_namespace: str = "" + self, + namespace: str, + async_enabled: bool = False, + target_namespace: str = "", + api_token: str = "", ) -> None: self.hosts: Dict[str, InstillInstance] = {} self.namespace: str = namespace @@ -39,14 +43,27 @@ def __init__( if global_config.hosts is not None: for instance, config in global_config.hosts.items(): + token = config.token + if api_token != "" and instance == self.instance: + token = api_token self.hosts[instance] = InstillInstance( pipeline_service.PipelinePublicServiceStub, url=config.url, - token=config.token, + token=token, secure=config.secure, async_enabled=async_enabled, ) + def close(self): + if self.is_serving(): + for host in self.hosts.values(): + host.channel.close() + + async def async_close(self): + if self.is_serving(): + for host in self.hosts.values(): + await host.async_channel.close() + @property def hosts(self): return self._hosts @@ -63,6 +80,9 @@ def instance(self): def instance(self, instance: str): self._instance = instance + def set_instance(self, instance: str): + self._instance = instance + @property def metadata(self): return self._metadata @@ -303,17 +323,12 @@ def clone_pipeline( def trigger_pipeline( self, name: str, - inputs: list, data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerUserPipelineResponse: request = pipeline_interface.TriggerUserPipelineRequest( name=f"{self.target_namespace}/pipelines/{name}", ) - for input_value in inputs: - trigger_inputs = Struct() - trigger_inputs.update(input_value) - request.inputs.append(trigger_inputs) for d in data: trigger_data = pipeline_interface.TriggerData() trigger_data.variable.update(d) @@ -336,17 +351,12 @@ def trigger_pipeline( def trigger_async_pipeline( self, name: str, - inputs: list, data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerAsyncUserPipelineResponse: request = pipeline_interface.TriggerAsyncUserPipelineRequest( name=f"{self.target_namespace}/pipelines/{name}", ) - for input_value in inputs: - trigger_inputs = Struct() - trigger_inputs.update(input_value) - request.inputs.append(trigger_inputs) for d in data: trigger_data = pipeline_interface.TriggerData() trigger_data.variable.update(d) @@ -450,30 +460,28 @@ def trigger_namespace_pipeline( self, namespace_id: str, pipeline_id: str, - inputs: list[Struct], - data: list[pipeline_interface.TriggerData], + data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerNamespacePipelineResponse: + request = pipeline_interface.TriggerNamespacePipelineRequest( + namespace_id=namespace_id, + pipeline_id=pipeline_id, + ) + for d in data: + trigger_data = pipeline_interface.TriggerData() + trigger_data.variable.update(d) + request.data.append(trigger_data) + if async_enabled: return RequestFactory( method=self.hosts[self.instance].async_client.TriggerNamespacePipeline, - request=pipeline_interface.TriggerNamespacePipelineRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_async() return RequestFactory( method=self.hosts[self.instance].client.TriggerNamespacePipeline, - request=pipeline_interface.TriggerNamespacePipelineRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_sync() @@ -482,32 +490,30 @@ def trigger_namespace_pipeline_with_stream( self, namespace_id: str, pipeline_id: str, - inputs: list[Struct], - data: list[pipeline_interface.TriggerData], + data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerNamespacePipelineWithStreamResponse: + request = pipeline_interface.TriggerNamespacePipelineWithStreamRequest( + namespace_id=namespace_id, + pipeline_id=pipeline_id, + ) + for d in data: + trigger_data = pipeline_interface.TriggerData() + trigger_data.variable.update(d) + request.data.append(trigger_data) + if async_enabled: return RequestFactory( method=self.hosts[ self.instance ].async_client.TriggerNamespacePipelineWithStream, - request=pipeline_interface.TriggerNamespacePipelineWithStreamRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_async() return RequestFactory( method=self.hosts[self.instance].client.TriggerNamespacePipelineWithStream, - request=pipeline_interface.TriggerNamespacePipelineWithStreamRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_sync() @@ -516,32 +522,30 @@ def trigger_async_namespace_pipeline( self, namespace_id: str, pipeline_id: str, - inputs: list[Struct], - data: list[pipeline_interface.TriggerData], + data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerAsyncNamespacePipelineResponse: + request = pipeline_interface.TriggerAsyncNamespacePipelineRequest( + namespace_id=namespace_id, + pipeline_id=pipeline_id, + ) + for d in data: + trigger_data = pipeline_interface.TriggerData() + trigger_data.variable.update(d) + request.data.append(trigger_data) + if async_enabled: return RequestFactory( method=self.hosts[ self.instance ].async_client.TriggerAsyncNamespacePipeline, - request=pipeline_interface.TriggerAsyncNamespacePipelineRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_async() return RequestFactory( method=self.hosts[self.instance].client.TriggerAsyncNamespacePipeline, - request=pipeline_interface.TriggerAsyncNamespacePipelineRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_sync() @@ -765,34 +769,31 @@ def trigger_namespace_pipeline_release( namespace_id: str, pipeline_id: str, release_id: str, - inputs: list[Struct], - data: list[pipeline_interface.TriggerData], + data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerNamespacePipelineReleaseResponse: + request = pipeline_interface.TriggerNamespacePipelineReleaseRequest( + namespace_id=namespace_id, + pipeline_id=pipeline_id, + release_id=release_id, + ) + for d in data: + trigger_data = pipeline_interface.TriggerData() + trigger_data.variable.update(d) + request.data.append(trigger_data) + if async_enabled: return RequestFactory( method=self.hosts[ self.instance ].async_client.TriggerNamespacePipelineRelease, - request=pipeline_interface.TriggerNamespacePipelineReleaseRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - release_id=release_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_async() return RequestFactory( method=self.hosts[self.instance].client.TriggerNamespacePipelineRelease, - request=pipeline_interface.TriggerNamespacePipelineReleaseRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - release_id=release_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_sync() @@ -802,22 +803,25 @@ def trigger_async_namespace_pipeline_release( namespace_id: str, pipeline_id: str, release_id: str, - inputs: list[Struct], - data: list[pipeline_interface.TriggerData], + data: list, async_enabled: bool = False, ) -> pipeline_interface.TriggerAsyncNamespacePipelineReleaseResponse: + request = pipeline_interface.TriggerAsyncNamespacePipelineReleaseRequest( + namespace_id=namespace_id, + pipeline_id=pipeline_id, + release_id=release_id, + ) + for d in data: + trigger_data = pipeline_interface.TriggerData() + trigger_data.variable.update(d) + request.data.append(trigger_data) + if async_enabled: return RequestFactory( method=self.hosts[ self.instance ].async_client.TriggerAsyncNamespacePipelineRelease, - request=pipeline_interface.TriggerAsyncNamespacePipelineReleaseRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - release_id=release_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_async() @@ -825,13 +829,7 @@ def trigger_async_namespace_pipeline_release( method=self.hosts[ self.instance ].client.TriggerAsyncNamespacePipelineRelease, - request=pipeline_interface.TriggerAsyncNamespacePipelineReleaseRequest( - namespace_id=namespace_id, - pipeline_id=pipeline_id, - release_id=release_id, - inputs=inputs, - data=data, - ), + request=request, metadata=self.hosts[self.instance].metadata, ).send_sync() diff --git a/instill/configuration/__init__.py b/instill/configuration/__init__.py index ff2a6e91..71a7a273 100644 --- a/instill/configuration/__init__.py +++ b/instill/configuration/__init__.py @@ -28,6 +28,10 @@ class _Config(BaseModel): class Configuration: def __init__(self) -> None: self._config: _Config + self.load() + + if "default" not in self._config.hosts: + self.set_default(url="api.instill.tech", secure=True, token="") CONFIG_DIR.mkdir(exist_ok=True) @@ -68,4 +72,3 @@ def set_default(self, url: str, token: str, secure: bool): global_config = Configuration() -global_config.load() diff --git a/instill/tests/test_client.py b/instill/tests/test_client.py index e6b66398..cc51eed8 100644 --- a/instill/tests/test_client.py +++ b/instill/tests/test_client.py @@ -4,6 +4,7 @@ import instill.protogen.model.model.v1alpha.model_public_service_pb2_grpc as model_service import instill.protogen.vdp.pipeline.v1beta.pipeline_public_service_pb2_grpc as pipeline_service from instill.clients import MgmtClient, ModelClient, PipelineClient +from instill.clients.constant import DEFAULT_INSTANCE from instill.clients.instance import InstillInstance @@ -11,11 +12,11 @@ def describe_client(): def describe_instance(): def when_not_set(expect): mgmt_client = MgmtClient(False) - expect(mgmt_client.instance) == "" + expect(mgmt_client.instance) == DEFAULT_INSTANCE model_client = ModelClient(namespace="", async_enabled=False) - expect(model_client.instance) == "" + expect(model_client.instance) == DEFAULT_INSTANCE pipeline_client = PipelineClient(namespace="", async_enabled=False) - expect(pipeline_client.instance) == "" + expect(pipeline_client.instance) == DEFAULT_INSTANCE def when_set_correct_type(expect): mgmt_client = MgmtClient(False) diff --git a/tests/test_client.py b/tests/test_client.py index e6b66398..cc51eed8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,6 +4,7 @@ import instill.protogen.model.model.v1alpha.model_public_service_pb2_grpc as model_service import instill.protogen.vdp.pipeline.v1beta.pipeline_public_service_pb2_grpc as pipeline_service from instill.clients import MgmtClient, ModelClient, PipelineClient +from instill.clients.constant import DEFAULT_INSTANCE from instill.clients.instance import InstillInstance @@ -11,11 +12,11 @@ def describe_client(): def describe_instance(): def when_not_set(expect): mgmt_client = MgmtClient(False) - expect(mgmt_client.instance) == "" + expect(mgmt_client.instance) == DEFAULT_INSTANCE model_client = ModelClient(namespace="", async_enabled=False) - expect(model_client.instance) == "" + expect(model_client.instance) == DEFAULT_INSTANCE pipeline_client = PipelineClient(namespace="", async_enabled=False) - expect(pipeline_client.instance) == "" + expect(pipeline_client.instance) == DEFAULT_INSTANCE def when_set_correct_type(expect): mgmt_client = MgmtClient(False)