Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model):add model web management #613

Merged
merged 14 commits into from
Sep 22, 2023
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinxcontrib.autodoc_pydantic",
"sphinxcontrib.autodoc_pydantic_base",
"myst_nb",
"sphinx_copybutton",
"sphinx_panels",
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ toml
myst_nb
sphinx_copybutton
pydata-sphinx-theme==0.13.1
pydantic-settings
furo
112 changes: 112 additions & 0 deletions pilot/connections/conn_spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Optional, Any
from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
"""Spark Connect
Args:
Usage:
"""

"""db type"""
db_type: str = "spark"
"""db driver"""
driver: str = "spark"
"""db dialect"""
dialect: str = "sparksql"

def __init__(
self,
file_path: str,
spark_session: Optional[SparkSession] = None,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Initialize the Spark DataFrame from Datasource path
return: Spark DataFrame
"""
self.spark_session = (
spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
)
self.path = file_path
self.table_name = "temp"
self.df = self.create_df(self.path)

@classmethod
def from_file_path(
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
):
try:
return cls(file_path=file_path, engine_args=engine_args)

except Exception as e:
print("load spark datasource error" + str(e))

def create_df(self, path) -> DataFrame:
"""Create a Spark DataFrame from Datasource path
return: Spark DataFrame
"""
return self.spark_session.read.option("header", "true").csv(path)

def run(self, sql):
# self.log(f"llm ingestion sql query is :\n{sql}")
# self.df = self.create_df(self.path)
self.df.createOrReplaceTempView(self.table_name)
df = self.spark_session.sql(sql)
first_row = df.first()
rows = [first_row.asDict().keys()]
for row in df.collect():
rows.append(row)
return rows

def query_ex(self, sql):
rows = self.run(sql)
field_names = rows[0]
return field_names, rows

def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""

def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""

return "ans"

def get_fields(self):
"""Get column meta about dataframe."""
return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

def get_users(self):
return []

def get_grants(self):
return []

def get_collation(self):
"""Get collation."""
return "UTF-8"

def get_charset(self):
return "UTF-8"

def get_db_list(self):
return ["default"]

def get_db_names(self):
return ["default"]

def get_database_list(self):
return []

def get_database_names(self):
return []

def table_simple_info(self):
return f"{self.table_name}{self.get_fields()}"

def get_table_comments(self, db_name):
return ""
2 changes: 2 additions & 0 deletions pilot/server/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.server.llm_manage.api import router as llm_manage_api


from pilot.openapi.api_v1.api_v1 import router as api_v1
Expand Down Expand Up @@ -73,6 +74,7 @@ def swagger_monkey_patch(*args, **kwargs):
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(llm_manage_api, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
Expand Down
108 changes: 108 additions & 0 deletions pilot/server/llm_manage/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List

from fastapi import APIRouter

from pilot.component import ComponentType
from pilot.configs.config import Config

from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory
from pilot.openapi.api_view_model import Result

from pilot.server.llm_manage.request.request import ModelResponse

CFG = Config()
router = APIRouter()


@router.get("/v1/worker/model/params")
async def model_params():
print(f"/worker/model/params")
try:
from pilot.model.cluster import WorkerManagerFactory

worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
params = []
workers = await worker_manager.supported_models()
for worker in workers:
for model in worker.models:
model_dict = model.__dict__
model_dict["host"] = worker.host
model_dict["port"] = worker.port
params.append(model_dict)
return Result.succ(params)
if not worker_instance:
return Result.faild(code="E000X", msg=f"can not find worker manager")
except Exception as e:
return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.get("/v1/worker/model/list")
async def model_list():
print(f"/worker/model/list")
try:
from pilot.model.cluster.controller.controller import BaseModelController

controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
responses = []
managers = await controller.get_all_instances(
model_name="WorkerManager@service", healthy_only=True
)
manager_map = dict(map(lambda manager: (manager.host, manager), managers))
models = await controller.get_all_instances()
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type == "llm" or worker_type == "text2vec":
response = ModelResponse(
model_name=worker_name,
model_type=worker_type,
host=model.host,
port=model.port,
healthy=model.healthy,
check_healthy=model.check_healthy,
last_heartbeat=model.last_heartbeat,
prompt_template=model.prompt_template,
)
response.manager_host = model.host if manager_map[model.host] else None
response.manager_port = (
manager_map[model.host].port if manager_map[model.host] else None
)
responses.append(response)
return Result.succ(responses)

except Exception as e:
return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/v1/worker/model/stop")
async def model_stop(request: WorkerStartupRequest):
print(f"/v1/worker/model/stop:")
try:
from pilot.model.cluster.controller.controller import BaseModelController

worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
return Result.faild(code="E000X", msg=f"can not find worker manager")
request.params = {}
return Result.succ(await worker_manager.model_shutdown(request))
except Exception as e:
return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.post("/v1/worker/model/start")
async def model_start(request: WorkerStartupRequest):
print(f"/v1/worker/model/start:")
try:
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
return Result.faild(code="E000X", msg=f"can not find worker manager")
return Result.succ(await worker_manager.model_startup(request))
except Exception as e:
return Result.faild(code="E000X", msg=f"model start failed {e}")
28 changes: 28 additions & 0 deletions pilot/server/llm_manage/request/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class ModelResponse:
"""ModelRequest"""

"""model_name: model_name"""
model_name: str = None
"""model_type: model_type"""
model_type: str = None
"""host: host"""
host: str = None
"""port: port"""
port: int = None
"""manager_host: manager_host"""
manager_host: str = None
"""manager_port: manager_port"""
manager_port: int = None
"""healthy: healthy"""
healthy: bool = True

"""check_healthy: check_healthy"""
check_healthy: bool = True
prompt_template: str = None
last_heartbeat: str = None
stream_api: str = None
nostream_api: str = None
2 changes: 1 addition & 1 deletion pilot/server/static/404.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pilot/server/static/404/index.html

Large diffs are not rendered by default.

This file was deleted.

Loading