Skip to content

Commit

Permalink
feat(model):add model web management (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Sep 22, 2023
2 parents 3590d7b + d512dde commit 9979b6a
Show file tree
Hide file tree
Showing 43 changed files with 373 additions and 111 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinxcontrib.autodoc_pydantic",
"sphinxcontrib.autodoc_pydantic_base",
"myst_nb",
"sphinx_copybutton",
"sphinx_panels",
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ toml
myst_nb
sphinx_copybutton
pydata-sphinx-theme==0.13.1
pydantic-settings
furo
112 changes: 112 additions & 0 deletions pilot/connections/conn_spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Optional, Any
from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
"""Spark Connect
Args:
Usage:
"""

"""db type"""
db_type: str = "spark"
"""db driver"""
driver: str = "spark"
"""db dialect"""
dialect: str = "sparksql"

def __init__(
self,
file_path: str,
spark_session: Optional[SparkSession] = None,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Initialize the Spark DataFrame from Datasource path
return: Spark DataFrame
"""
self.spark_session = (
spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
)
self.path = file_path
self.table_name = "temp"
self.df = self.create_df(self.path)

@classmethod
def from_file_path(
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
):
try:
return cls(file_path=file_path, engine_args=engine_args)

except Exception as e:
print("load spark datasource error" + str(e))

def create_df(self, path) -> DataFrame:
"""Create a Spark DataFrame from Datasource path
return: Spark DataFrame
"""
return self.spark_session.read.option("header", "true").csv(path)

def run(self, sql):
# self.log(f"llm ingestion sql query is :\n{sql}")
# self.df = self.create_df(self.path)
self.df.createOrReplaceTempView(self.table_name)
df = self.spark_session.sql(sql)
first_row = df.first()
rows = [first_row.asDict().keys()]
for row in df.collect():
rows.append(row)
return rows

def query_ex(self, sql):
rows = self.run(sql)
field_names = rows[0]
return field_names, rows

def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""

def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""

return "ans"

def get_fields(self):
"""Get column meta about dataframe."""
return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

def get_users(self):
return []

def get_grants(self):
return []

def get_collation(self):
"""Get collation."""
return "UTF-8"

def get_charset(self):
return "UTF-8"

def get_db_list(self):
return ["default"]

def get_db_names(self):
return ["default"]

def get_database_list(self):
return []

def get_database_names(self):
return []

def table_simple_info(self):
return f"{self.table_name}{self.get_fields()}"

def get_table_comments(self, db_name):
return ""
2 changes: 2 additions & 0 deletions pilot/server/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.server.llm_manage.api import router as llm_manage_api


from pilot.openapi.api_v1.api_v1 import router as api_v1
Expand Down Expand Up @@ -73,6 +74,7 @@ def swagger_monkey_patch(*args, **kwargs):
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(llm_manage_api, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
Expand Down
108 changes: 108 additions & 0 deletions pilot/server/llm_manage/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List

from fastapi import APIRouter

from pilot.component import ComponentType
from pilot.configs.config import Config

from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory
from pilot.openapi.api_view_model import Result

from pilot.server.llm_manage.request.request import ModelResponse

CFG = Config()
router = APIRouter()


@router.get("/v1/worker/model/params")
async def model_params():
print(f"/worker/model/params")
try:
from pilot.model.cluster import WorkerManagerFactory

worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
params = []
workers = await worker_manager.supported_models()
for worker in workers:
for model in worker.models:
model_dict = model.__dict__
model_dict["host"] = worker.host
model_dict["port"] = worker.port
params.append(model_dict)
return Result.succ(params)
if not worker_instance:
return Result.faild(code="E000X", msg=f"can not find worker manager")
except Exception as e:
return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.get("/v1/worker/model/list")
async def model_list():
print(f"/worker/model/list")
try:
from pilot.model.cluster.controller.controller import BaseModelController

controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
responses = []
managers = await controller.get_all_instances(
model_name="WorkerManager@service", healthy_only=True
)
manager_map = dict(map(lambda manager: (manager.host, manager), managers))
models = await controller.get_all_instances()
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type == "llm" or worker_type == "text2vec":
response = ModelResponse(
model_name=worker_name,
model_type=worker_type,
host=model.host,
port=model.port,
healthy=model.healthy,
check_healthy=model.check_healthy,
last_heartbeat=model.last_heartbeat,
prompt_template=model.prompt_template,
)
response.manager_host = model.host if manager_map[model.host] else None
response.manager_port = (
manager_map[model.host].port if manager_map[model.host] else None
)
responses.append(response)
return Result.succ(responses)

except Exception as e:
return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/v1/worker/model/stop")
async def model_stop(request: WorkerStartupRequest):
print(f"/v1/worker/model/stop:")
try:
from pilot.model.cluster.controller.controller import BaseModelController

worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
return Result.faild(code="E000X", msg=f"can not find worker manager")
request.params = {}
return Result.succ(await worker_manager.model_shutdown(request))
except Exception as e:
return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.post("/v1/worker/model/start")
async def model_start(request: WorkerStartupRequest):
print(f"/v1/worker/model/start:")
try:
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
return Result.faild(code="E000X", msg=f"can not find worker manager")
return Result.succ(await worker_manager.model_startup(request))
except Exception as e:
return Result.faild(code="E000X", msg=f"model start failed {e}")
28 changes: 28 additions & 0 deletions pilot/server/llm_manage/request/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class ModelResponse:
"""ModelRequest"""

"""model_name: model_name"""
model_name: str = None
"""model_type: model_type"""
model_type: str = None
"""host: host"""
host: str = None
"""port: port"""
port: int = None
"""manager_host: manager_host"""
manager_host: str = None
"""manager_port: manager_port"""
manager_port: int = None
"""healthy: healthy"""
healthy: bool = True

"""check_healthy: check_healthy"""
check_healthy: bool = True
prompt_template: str = None
last_heartbeat: str = None
stream_api: str = None
nostream_api: str = None
2 changes: 1 addition & 1 deletion pilot/server/static/404.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pilot/server/static/404/index.html

Large diffs are not rendered by default.

This file was deleted.

Loading

0 comments on commit 9979b6a

Please sign in to comment.