-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(model):add model web management (#613)
- Loading branch information
Showing
43 changed files
with
373 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ toml | |
myst_nb | ||
sphinx_copybutton | ||
pydata-sphinx-theme==0.13.1 | ||
pydantic-settings | ||
furo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from typing import Optional, Any | ||
from pyspark.sql import SparkSession, DataFrame | ||
from sqlalchemy import text | ||
|
||
from pilot.connections.base import BaseConnect | ||
|
||
|
||
class SparkConnect(BaseConnect): | ||
"""Spark Connect | ||
Args: | ||
Usage: | ||
""" | ||
|
||
"""db type""" | ||
db_type: str = "spark" | ||
"""db driver""" | ||
driver: str = "spark" | ||
"""db dialect""" | ||
dialect: str = "sparksql" | ||
|
||
def __init__( | ||
self, | ||
file_path: str, | ||
spark_session: Optional[SparkSession] = None, | ||
engine_args: Optional[dict] = None, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Initialize the Spark DataFrame from Datasource path | ||
return: Spark DataFrame | ||
""" | ||
self.spark_session = ( | ||
spark_session or SparkSession.builder.appName("dbgpt").getOrCreate() | ||
) | ||
self.path = file_path | ||
self.table_name = "temp" | ||
self.df = self.create_df(self.path) | ||
|
||
@classmethod | ||
def from_file_path( | ||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any | ||
): | ||
try: | ||
return cls(file_path=file_path, engine_args=engine_args) | ||
|
||
except Exception as e: | ||
print("load spark datasource error" + str(e)) | ||
|
||
def create_df(self, path) -> DataFrame: | ||
"""Create a Spark DataFrame from Datasource path | ||
return: Spark DataFrame | ||
""" | ||
return self.spark_session.read.option("header", "true").csv(path) | ||
|
||
def run(self, sql): | ||
# self.log(f"llm ingestion sql query is :\n{sql}") | ||
# self.df = self.create_df(self.path) | ||
self.df.createOrReplaceTempView(self.table_name) | ||
df = self.spark_session.sql(sql) | ||
first_row = df.first() | ||
rows = [first_row.asDict().keys()] | ||
for row in df.collect(): | ||
rows.append(row) | ||
return rows | ||
|
||
def query_ex(self, sql): | ||
rows = self.run(sql) | ||
field_names = rows[0] | ||
return field_names, rows | ||
|
||
def get_indexes(self, table_name): | ||
"""Get table indexes about specified table.""" | ||
return "" | ||
|
||
def get_show_create_table(self, table_name): | ||
"""Get table show create table about specified table.""" | ||
|
||
return "ans" | ||
|
||
def get_fields(self): | ||
"""Get column meta about dataframe.""" | ||
return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes]) | ||
|
||
def get_users(self): | ||
return [] | ||
|
||
def get_grants(self): | ||
return [] | ||
|
||
def get_collation(self): | ||
"""Get collation.""" | ||
return "UTF-8" | ||
|
||
def get_charset(self): | ||
return "UTF-8" | ||
|
||
def get_db_list(self): | ||
return ["default"] | ||
|
||
def get_db_names(self): | ||
return ["default"] | ||
|
||
def get_database_list(self): | ||
return [] | ||
|
||
def get_database_names(self): | ||
return [] | ||
|
||
def table_simple_info(self): | ||
return f"{self.table_name}{self.get_fields()}" | ||
|
||
def get_table_comments(self, db_name): | ||
return "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from typing import List | ||
|
||
from fastapi import APIRouter | ||
|
||
from pilot.component import ComponentType | ||
from pilot.configs.config import Config | ||
|
||
from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory | ||
from pilot.openapi.api_view_model import Result | ||
|
||
from pilot.server.llm_manage.request.request import ModelResponse | ||
|
||
CFG = Config() | ||
router = APIRouter() | ||
|
||
|
||
@router.get("/v1/worker/model/params") | ||
async def model_params(): | ||
print(f"/worker/model/params") | ||
try: | ||
from pilot.model.cluster import WorkerManagerFactory | ||
|
||
worker_manager = CFG.SYSTEM_APP.get_component( | ||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory | ||
).create() | ||
params = [] | ||
workers = await worker_manager.supported_models() | ||
for worker in workers: | ||
for model in worker.models: | ||
model_dict = model.__dict__ | ||
model_dict["host"] = worker.host | ||
model_dict["port"] = worker.port | ||
params.append(model_dict) | ||
return Result.succ(params) | ||
if not worker_instance: | ||
return Result.faild(code="E000X", msg=f"can not find worker manager") | ||
except Exception as e: | ||
return Result.faild(code="E000X", msg=f"model stop failed {e}") | ||
|
||
|
||
@router.get("/v1/worker/model/list") | ||
async def model_list(): | ||
print(f"/worker/model/list") | ||
try: | ||
from pilot.model.cluster.controller.controller import BaseModelController | ||
|
||
controller = CFG.SYSTEM_APP.get_component( | ||
ComponentType.MODEL_CONTROLLER, BaseModelController | ||
) | ||
responses = [] | ||
managers = await controller.get_all_instances( | ||
model_name="WorkerManager@service", healthy_only=True | ||
) | ||
manager_map = dict(map(lambda manager: (manager.host, manager), managers)) | ||
models = await controller.get_all_instances() | ||
for model in models: | ||
worker_name, worker_type = model.model_name.split("@") | ||
if worker_type == "llm" or worker_type == "text2vec": | ||
response = ModelResponse( | ||
model_name=worker_name, | ||
model_type=worker_type, | ||
host=model.host, | ||
port=model.port, | ||
healthy=model.healthy, | ||
check_healthy=model.check_healthy, | ||
last_heartbeat=model.last_heartbeat, | ||
prompt_template=model.prompt_template, | ||
) | ||
response.manager_host = model.host if manager_map[model.host] else None | ||
response.manager_port = ( | ||
manager_map[model.host].port if manager_map[model.host] else None | ||
) | ||
responses.append(response) | ||
return Result.succ(responses) | ||
|
||
except Exception as e: | ||
return Result.faild(code="E000X", msg=f"space list error {e}") | ||
|
||
|
||
@router.post("/v1/worker/model/stop") | ||
async def model_stop(request: WorkerStartupRequest): | ||
print(f"/v1/worker/model/stop:") | ||
try: | ||
from pilot.model.cluster.controller.controller import BaseModelController | ||
|
||
worker_manager = CFG.SYSTEM_APP.get_component( | ||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory | ||
).create() | ||
if not worker_manager: | ||
return Result.faild(code="E000X", msg=f"can not find worker manager") | ||
request.params = {} | ||
return Result.succ(await worker_manager.model_shutdown(request)) | ||
except Exception as e: | ||
return Result.faild(code="E000X", msg=f"model stop failed {e}") | ||
|
||
|
||
@router.post("/v1/worker/model/start") | ||
async def model_start(request: WorkerStartupRequest): | ||
print(f"/v1/worker/model/start:") | ||
try: | ||
worker_manager = CFG.SYSTEM_APP.get_component( | ||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory | ||
).create() | ||
if not worker_manager: | ||
return Result.faild(code="E000X", msg=f"can not find worker manager") | ||
return Result.succ(await worker_manager.model_startup(request)) | ||
except Exception as e: | ||
return Result.faild(code="E000X", msg=f"model start failed {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class ModelResponse: | ||
"""ModelRequest""" | ||
|
||
"""model_name: model_name""" | ||
model_name: str = None | ||
"""model_type: model_type""" | ||
model_type: str = None | ||
"""host: host""" | ||
host: str = None | ||
"""port: port""" | ||
port: int = None | ||
"""manager_host: manager_host""" | ||
manager_host: str = None | ||
"""manager_port: manager_port""" | ||
manager_port: int = None | ||
"""healthy: healthy""" | ||
healthy: bool = True | ||
|
||
"""check_healthy: check_healthy""" | ||
check_healthy: bool = True | ||
prompt_template: str = None | ||
last_heartbeat: str = None | ||
stream_api: str = None | ||
nostream_api: str = None |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
1 change: 0 additions & 1 deletion
1
pilot/server/static/_next/static/_BY-cQzLf2lL8o4uTsVNy/_buildManifest.js
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.