Refactor Configs #218

Merged: 25 commits merged on Sep 22, 2023
Changes from 8 commits
Commits (25):
5c29257
refactoring configs
mrwyattii Aug 4, 2023
206fc2c
further refactoring
mrwyattii Aug 8, 2023
926795d
fix errors, get things running
mrwyattii Aug 8, 2023
0b04ed2
formatting and flake fixes
mrwyattii Aug 8, 2023
e5aea20
remove formatting changes
mrwyattii Aug 8, 2023
b0f3f65
get tests working again
mrwyattii Aug 8, 2023
7e206ef
fix non-persistent deployment error
mrwyattii Aug 8, 2023
018215e
fix restful test
mrwyattii Aug 8, 2023
c77c78d
fix zero config test failures
mrwyattii Aug 9, 2023
3f50926
attempt to fix hf_auth error
mrwyattii Aug 9, 2023
b9ab6ed
Merge branch 'main' into mrwyattii/refactor-config
mrwyattii Aug 29, 2023
ff25493
fix mistakes from merge with main branch
mrwyattii Aug 29, 2023
b58a93d
fixes for unit tests
mrwyattii Aug 31, 2023
a5333fb
Merge branch 'main' into mrwyattii/refactor-config
mrwyattii Aug 31, 2023
11f5f9f
skip re-validating configs when we launch sub-processes
mrwyattii Sep 1, 2023
572231e
resolve some remaining bugs
mrwyattii Sep 12, 2023
cfbb6b9
move erroneously added exit()
mrwyattii Sep 12, 2023
e42a275
Merge branch 'main' into mrwyattii/refactor-config
mrwyattii Sep 12, 2023
2253686
fix error from merge conflict resolution
mrwyattii Sep 14, 2023
3eab7c3
resolved comments on PR
mrwyattii Sep 19, 2023
339c1fe
additional QoL improvements to inference pipeline kwarg passing
mrwyattii Sep 19, 2023
8bcc061
fixes
mrwyattii Sep 20, 2023
3e1e8af
fix type hint
mrwyattii Sep 20, 2023
f3d970d
update varible names in tests
mrwyattii Sep 20, 2023
09b52e3
fix wrong arg parsing name
mrwyattii Sep 20, 2023
mii/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -7,10 +7,10 @@
 from .client import MIIClient, mii_query_handle
 from .deployment import deploy
 from .terminate import terminate
-from .constants import DeploymentType, Tasks
+from .constants import DeploymentType, TaskType
 from .aml_related.utils import aml_output_path

-from .config import MIIConfig, LoadBalancerConfig
+from .config import MIIConfig, DeploymentConfig
 from .grpc_related.proto import modelresponse_pb2_grpc

 __version__ = "0.0.0"
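The rename is visible at the package surface: user code that imported Tasks and LoadBalancerConfig from mii now imports TaskType and DeploymentConfig. A minimal sketch of the updated import site, using only names shown in this diff:

# Old exports were Tasks and LoadBalancerConfig; the refactored package
# exposes TaskType and DeploymentConfig instead (see mii/__init__.py above).
from mii import MIIConfig, DeploymentConfig, TaskType

# TaskType members referenced later in this PR include TEXT_GENERATION,
# QUESTION_ANSWERING, and CONVERSATIONAL.
print(TaskType.TEXT_GENERATION)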
mii/client.py (81 changes: 44 additions & 37 deletions)

@@ -6,20 +6,15 @@
 import grpc
 import requests
 import mii
-from mii.utils import get_task
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
-from mii.constants import GRPC_MAX_MSG_SIZE, Tasks
+from mii.constants import GRPC_MAX_MSG_SIZE, TaskType
 from mii.method_table import GRPC_METHOD_TABLE
+from mii.config import MIIConfig


-def _get_deployment_info(deployment_name):
-    configs = mii.utils.import_score_file(deployment_name).configs
-    task = configs[mii.constants.TASK_NAME_KEY]
-    mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
-    mii_configs = mii.config.MIIConfig(**mii_configs_dict)
-
-    assert task is not None, "The task name should be set before calling init"
-    return task, mii_configs
+def _get_mii_config(deployment_name):
+    mii_config = mii.utils.import_score_file(deployment_name).mii_config
+    return MIIConfig(**mii_config)


 def mii_query_handle(deployment_name):
@@ -39,27 +34,33 @@ def mii_query_handle(deployment_name):
         inference_pipeline, task = mii.non_persistent_models[deployment_name]
         return MIINonPersistentClient(task, deployment_name)

-    task_name, mii_configs = _get_deployment_info(deployment_name)
-    return MIIClient(task_name, "localhost", mii_configs.port_number)
+    mii_config = _get_mii_config(deployment_name)
+    return MIIClient(mii_config.deployment_config.task,
+                     "localhost",
+                     mii_config.port_number)


 def create_channel(host, port):
-    return grpc.aio.insecure_channel(f'{host}:{port}',
-                                     options=[('grpc.max_send_message_length',
-                                               GRPC_MAX_MSG_SIZE),
-                                              ('grpc.max_receive_message_length',
-                                               GRPC_MAX_MSG_SIZE)])
-
-
-class MIIClient():
+    return grpc.aio.insecure_channel(
+        f"{host}:{port}",
+        options=[
+            ("grpc.max_send_message_length",
+             GRPC_MAX_MSG_SIZE),
+            ("grpc.max_receive_message_length",
+             GRPC_MAX_MSG_SIZE),
+        ],
+    )
+
+
+class MIIClient:
     """
     Client to send queries to a single endpoint.
     """
-    def __init__(self, task_name, host, port):
+    def __init__(self, task, host, port):
         self.asyncio_loop = asyncio.get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
-        self.task = get_task(task_name)
+        self.task = task

     async def _request_async_response(self, request_dict, **query_kwargs):
         if self.task not in GRPC_METHOD_TABLE:
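With this hunk, the query handle takes its task from the validated MIIConfig (mii_config.deployment_config.task) instead of re-reading separate score-file entries. A hedged usage sketch; the deployment name and the generation kwarg are placeholders, and the query() call assumes the existing MIIClient request API that this hunk does not show:

import mii

# Assumes a persistent deployment with this name was created via mii.deploy(...).
client = mii.mii_query_handle("my-deployment")

# "query" is the request key used by the generic path further down in client.py;
# max_new_tokens is an illustrative kwarg, not taken from this diff.
result = client.query({"query": "DeepSpeed is"}, max_new_tokens=64)
print(result)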
@@ -87,7 +88,9 @@ async def create_session_async(self, session_id):
             modelresponse_pb2.SessionID(session_id=session_id))

     def create_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
         return self.asyncio_loop.run_until_complete(
             self.create_session_async(session_id))

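Session creation stays gated on the text-generation task; only the enum name changes from Tasks to TaskType. A minimal sketch, reusing a handle obtained as in the example above:

# Both calls assert that the task is TaskType.TEXT_GENERATION and raise
# an AssertionError for any other deployment type.
client.create_session("session-0")
client.destroy_session("session-0")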
@@ -96,18 +99,20 @@ async def destroy_session_async(self, session_id):
         )

     def destroy_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
         self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))


-class MIITensorParallelClient():
+class MIITensorParallelClient:
     """
     Client to send queries to multiple endpoints in parallel.
     This is used to call multiple servers deployed for tensor parallelism.
     """
-    def __init__(self, task_name, host, ports):
-        self.task = get_task(task_name)
-        self.clients = [MIIClient(task_name, host, port) for port in ports]
+    def __init__(self, task, host, ports):
+        self.task = task
+        self.clients = [MIIClient(task, host, port) for port in ports]
         self.asyncio_loop = asyncio.get_event_loop()

     # runs task in parallel and return the result from the first task
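MIITensorParallelClient now receives the task value directly rather than a task-name string it resolves with get_task(). A sketch of direct construction; the host and ports are illustrative, and in normal use these handles are created by MII's deployment plumbing rather than by hand:

from mii.client import MIITensorParallelClient
from mii.constants import TaskType

# One MIIClient is created per tensor-parallel server port.
tp_client = MIITensorParallelClient(TaskType.TEXT_GENERATION,
                                    "localhost",
                                    [50050, 50051])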
@@ -155,30 +160,32 @@ def destroy_session(self, session_id):
             client.destroy_session(session_id)


-class MIINonPersistentClient():
+class MIINonPersistentClient:
     def __init__(self, task, deployment_name):
         self.task = task
         self.deployment_name = deployment_name

     def query(self, request_dict, **query_kwargs):
-        assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
+        assert (
+            self.deployment_name in mii.non_persistent_models
+        ), f"deployment: {self.deployment_name} not found"
         task_methods = GRPC_METHOD_TABLE[self.task]
         inference_pipeline = mii.non_persistent_models[self.deployment_name][0]

-        if self.task == Tasks.QUESTION_ANSWERING:
-            if 'question' not in request_dict or 'context' not in request_dict:
+        if self.task == TaskType.QUESTION_ANSWERING:
+            if "question" not in request_dict or "context" not in request_dict:
                 raise Exception(
                     "Question Answering Task requires 'question' and 'context' keys")
             args = (request_dict["question"], request_dict["context"])
             kwargs = query_kwargs

-        elif self.task == Tasks.CONVERSATIONAL:
+        elif self.task == TaskType.CONVERSATIONAL:
             conv = task_methods.create_conversation(request_dict, **query_kwargs)
             args = (conv, )
             kwargs = {}

         else:
-            args = (request_dict['query'], )
+            args = (request_dict["query"], )
             kwargs = query_kwargs

         return task_methods.run_inference(inference_pipeline, args, query_kwargs)
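The question-answering branch keeps requiring both request keys; only the enum comparison and quoting style change. A sketch against a non-persistent deployment, with the deployment name as a placeholder:

import mii

# mii_query_handle returns an MIINonPersistentClient when the name is
# registered in mii.non_persistent_models (see mii_query_handle above).
qa_client = mii.mii_query_handle("qa-deployment")
answer = qa_client.query({
    "question": "What does MII stand for?",
    "context": "MII stands for Model Implementations for Inference.",
})
print(answer)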
Expand All @@ -189,6 +196,6 @@ def terminate(self):


def terminate_restful_gateway(deployment_name):
_, mii_configs = _get_deployment_info(deployment_name)
if mii_configs.enable_restful_api:
requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
mii_config = _get_mii_config(deployment_name)
if mii_config.enable_restful_api:
requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")
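Gateway shutdown now reads the RESTful flag and port from the single MIIConfig object. A short sketch, assuming the deployment was created with enable_restful_api turned on:

from mii.client import terminate_restful_gateway

# Sends the terminate request only if enable_restful_api was set in the config.
terminate_restful_gateway("my-deployment")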