Skip to content

Commit

Permalink
chore(main): new direct service initialization (#199)
Browse files Browse the repository at this point in the history
related changes for documentation at
https://github.com/instill-ai/instill.tech/pull/1087

Because

- users should be able to initialize service directly

This commit

- add new direct service initialization functions
  • Loading branch information
joremysh authored Sep 4, 2024
1 parent db5224e commit 3df80da
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 103 deletions.
20 changes: 18 additions & 2 deletions instill/clients/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class ArtifactClient(Client):
def __init__(self, async_enabled: bool) -> None:
def __init__(self, async_enabled: bool = False, api_token: str = "") -> None:
self.hosts: Dict[str, InstillInstance] = {}
if DEFAULT_INSTANCE in global_config.hosts:
self.instance = DEFAULT_INSTANCE
Expand All @@ -30,14 +30,27 @@ def __init__(self, async_enabled: bool) -> None:

if global_config.hosts is not None:
for instance, config in global_config.hosts.items():
token = config.token
if api_token != "" and instance == self.instance:
token = api_token
self.hosts[instance] = InstillInstance(
artifact_service.ArtifactPublicServiceStub,
url=config.url,
token=config.token,
token=token,
secure=config.secure,
async_enabled=async_enabled,
)

def close(self):
if self.is_serving():
for host in self.hosts.values():
host.channel.close()

async def async_close(self):
if self.is_serving():
for host in self.hosts.values():
await host.async_channel.close()

@property
def hosts(self):
return self._hosts
Expand All @@ -54,6 +67,9 @@ def instance(self):
def instance(self, instance: str):
self._instance = instance

def set_instance(self, instance: str):
self._instance = instance

@property
def metadata(self):
return self._metadata
Expand Down
51 changes: 51 additions & 0 deletions instill/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,54 @@ async def async_close(self):

def get_client(async_enabled: bool = False) -> InstillClient:
return InstillClient(async_enabled=async_enabled)


def init_artifact_client(
api_token: str = "", async_enabled: bool = False
) -> ArtifactClient:
client = ArtifactClient(api_token=api_token, async_enabled=async_enabled)
if not client.is_serving():
Logger.w(
"Instill Artifact is not serving, Artifact functionalities will not work"
)
raise NotServingException

return client


def init_model_client(api_token: str = "", async_enabled: bool = False) -> ModelClient:
mgmt_service = MgmtClient(api_token=api_token, async_enabled=async_enabled)
if not mgmt_service.is_serving():
Logger.w("Instill Core is required")
raise NotServingException

user_id = mgmt_service.get_user().user.id

client = ModelClient(
namespace=user_id, api_token=api_token, async_enabled=async_enabled
)
if not client.is_serving():
Logger.w("Instill Model is not serving, Model functionalities will not work")
raise NotServingException

return client


def init_pipeline_client(
api_token: str = "", async_enabled: bool = False
) -> PipelineClient:
mgmt_service = MgmtClient(api_token=api_token, async_enabled=async_enabled)
if not mgmt_service.is_serving():
Logger.w("Instill Core is required")
raise NotServingException

user_id = mgmt_service.get_user().user.id

client = PipelineClient(
namespace=user_id, api_token=api_token, async_enabled=async_enabled
)
if not client.is_serving():
Logger.w("Instill VDP is not serving, VDP functionalities will not work")
raise NotServingException

return client
12 changes: 8 additions & 4 deletions instill/clients/mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
from instill.configuration import global_config
from instill.utils.error_handler import grpc_handler

# from instill.utils.logger import Logger


class MgmtClient(Client):
def __init__(self, async_enabled: bool) -> None:
def __init__(self, async_enabled: bool = False, api_token: str = "") -> None:
self.hosts: Dict[str, InstillInstance] = {}
if DEFAULT_INSTANCE in global_config.hosts:
self.instance = DEFAULT_INSTANCE
Expand All @@ -31,10 +29,13 @@ def __init__(self, async_enabled: bool) -> None:

if global_config.hosts is not None:
for instance, config in global_config.hosts.items():
token = config.token
if api_token != "" and instance == self.instance:
token = api_token
self.hosts[instance] = InstillInstance(
mgmt_service.MgmtPublicServiceStub,
url=config.url,
token=config.token,
token=token,
secure=config.secure,
async_enabled=async_enabled,
)
Expand All @@ -55,6 +56,9 @@ def instance(self):
def instance(self, instance: str):
self._instance = instance

def set_instance(self, instance: str):
self._instance = instance

@property
def metadata(self):
return self._metadata
Expand Down
26 changes: 22 additions & 4 deletions instill/clients/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class ModelClient(Client):
def __init__(self, namespace: str, async_enabled: bool) -> None:
def __init__(
self, namespace: str, async_enabled: bool = False, api_token: str = ""
) -> None:
self.hosts: Dict[str, InstillInstance] = {}
self.namespace: str = namespace
if DEFAULT_INSTANCE in global_config.hosts:
Expand All @@ -31,14 +33,27 @@ def __init__(self, namespace: str, async_enabled: bool) -> None:

if global_config.hosts is not None:
for instance, config in global_config.hosts.items():
token = config.token
if api_token != "" and instance == self.instance:
token = api_token
self.hosts[instance] = InstillInstance(
model_service.ModelPublicServiceStub,
url=config.url,
token=config.token,
token=token,
secure=config.secure,
async_enabled=async_enabled,
)

def close(self):
if self.is_serving():
for host in self.hosts.values():
host.channel.close()

async def async_close(self):
if self.is_serving():
for host in self.hosts.values():
await host.async_channel.close()

@property
def hosts(self):
return self._hosts
Expand All @@ -55,6 +70,9 @@ def instance(self):
def instance(self, instance: str):
self._instance = instance

def set_instance(self, instance: str):
self._instance = instance

@property
def metadata(self):
return self._metadata
Expand Down Expand Up @@ -200,7 +218,7 @@ def trigger_model(
return RequestFactory(
method=self.hosts[self.instance].async_client.TriggerUserModel,
request=model_interface.TriggerUserModelRequest(
name=f"{self.namespace}/models/{model_name}",
name=f"namespaces/{self.namespace}/models/{model_name}",
task_inputs=task_inputs,
version=version,
),
Expand All @@ -210,7 +228,7 @@ def trigger_model(
return RequestFactory(
method=self.hosts[self.instance].client.TriggerUserModel,
request=model_interface.TriggerUserModelRequest(
name=f"{self.namespace}/models/{model_name}",
name=f"namespaces/{self.namespace}/models/{model_name}",
task_inputs=task_inputs,
version=version,
),
Expand Down
Loading

0 comments on commit 3df80da

Please sign in to comment.