From cb5ea242e0817a2082fda672ccc8a7ddce515ff3 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 11 Jul 2024 21:44:30 +0800 Subject: [PATCH] pydantic model_validate and model_dump apis update --- langport/core/cluster_node.py | 24 +++++++-------- langport/routers/gateway/common.py | 12 ++++---- langport/routers/gateway/openai_compatible.py | 10 +++---- langport/routers/server/core_node.py | 30 +++++++++---------- langport/routers/server/embedding_node.py | 2 +- langport/routers/server/generation_node.py | 11 ++++--- .../service/gateway/cluster_monitor_app.py | 4 +-- langport/service/gateway/graphite_feeder.py | 4 +-- langport/workers/generation_worker.py | 2 +- 9 files changed, 49 insertions(+), 50 deletions(-) diff --git a/langport/core/cluster_node.py b/langport/core/cluster_node.py index 5a48bbc..547879d 100644 --- a/langport/core/cluster_node.py +++ b/langport/core/cluster_node.py @@ -116,10 +116,10 @@ async def get_node_info(self, node_addr: str) -> NodeInfo: response = await self.client.post( node_addr + "/node_info", headers=self.headers, - json=NodeInfoRequest(node_id=self.node_id).dict(), + json=NodeInfoRequest(node_id=self.node_id).model_dump(), timeout=WORKER_API_TIMEOUT, ) - remote_node_info = NodeInfoResponse.parse_obj(response.json()) + remote_node_info = NodeInfoResponse.model_validate(response.json()) return remote_node_info.node_info async def register_node(self, target_node_addr: str, register_node_id: str, register_node_addr: str) -> bool: @@ -138,10 +138,10 @@ async def register_node(self, target_node_addr: str, register_node_id: str, regi response = await self.client.post( target_node_addr + "/register_node", headers=self.headers, - json=data.dict(), + json=data.model_dump(), timeout=WORKER_API_TIMEOUT, ) - remote = RegisterNodeResponse.parse_obj(response.json()) + remote = RegisterNodeResponse.model_validate(response.json()) self._add_node(remote.node_id, remote.node_addr) # fetch remote node info @@ -175,10 +175,10 @@ async def remove_node(self, target_node_addr: str, removed_node_id: str) -> bool response = await self.client.post( target_node_addr + "/remove_node", headers=self.headers, - json=data.dict(), + json=data.model_dump(), timeout=WORKER_API_TIMEOUT, ) - ret = RemoveNodeResponse.parse_obj(response.json()) + ret = RemoveNodeResponse.model_validate(response.json()) return True @@ -200,10 +200,10 @@ async def send_heartbeat(self, node_addr: str): response = await self.client.post( node_addr + "/heartbeat", headers=self.headers, - json=data.dict(), + json=data.model_dump(), timeout=WORKER_API_TIMEOUT, ) - ret = HeartbeatPong.parse_obj(response.json()) + ret = HeartbeatPong.model_validate(response.json()) return ret @@ -230,10 +230,10 @@ async def fetch_all_nodes(self, node_addr: str): response = await self.client.post( node_addr + "/node_list", headers=self.headers, - json=data.dict(), + json=data.model_dump(), timeout=WORKER_API_TIMEOUT, ) - ret = NodeListResponse.parse_obj(response.json()) + ret = NodeListResponse.model_validate(response.json()) return ret @@ -255,10 +255,10 @@ async def request_node_state(self, node_addr: str, name: str) -> GetNodeStateRes response = await self.client.post( node_addr + "/get_node_state", headers=self.headers, - json=data.dict(), + json=data.model_dump(), timeout=WORKER_API_TIMEOUT, ) - ret = GetNodeStateResponse.parse_obj(response.json()) + ret = GetNodeStateResponse.model_validate(response.json()) return ret diff --git a/langport/routers/gateway/common.py b/langport/routers/gateway/common.py index 084cb82..296e96b 100644 --- a/langport/routers/gateway/common.py +++ b/langport/routers/gateway/common.py @@ -43,12 +43,12 @@ class AppSettings(BaseSettings): def create_server_error_response(code: int, message: str) -> JSONResponse: return JSONResponse( - ErrorResponse(message=message, code=code).dict(), status_code=500 + ErrorResponse(message=message, code=code).model_dump(), status_code=500 ) def create_bad_request_response(code: int, message: str) -> JSONResponse: return JSONResponse( - ErrorResponse(message=message, code=code).dict(), status_code=400 + ErrorResponse(message=message, code=code).model_dump(), status_code=400 ) @@ -79,9 +79,9 @@ async def _get_worker_address( raise Exception("Error dispatch method.") ret = await client.post( controller_address + "/get_worker_address", - json=payload.dict(), + json=payload.model_dump(), ) - response = WorkerAddressResponse.parse_obj(ret.json()) + response = WorkerAddressResponse.model_validate(ret.json()) address_list = response.address_list values = [json.loads(obj) for obj in response.values] @@ -122,14 +122,14 @@ async def _list_models(app_settings: AppSettings, feature: Optional[str], client try: ret = await client.post( controller_address + "/get_worker_address", - json=payload.dict(), + json=payload.model_dump(), ) if ret.status_code != 200: return [] except Exception as e: print("[Exception] list model: ", e) return [] - response = WorkerAddressResponse.parse_obj(ret.json()) + response = WorkerAddressResponse.model_validate(ret.json()) address_list = response.address_list models = [json.loads(obj) for obj in response.values] diff --git a/langport/routers/gateway/openai_compatible.py b/langport/routers/gateway/openai_compatible.py index 15176db..cddfdbe 100644 --- a/langport/routers/gateway/openai_compatible.py +++ b/langport/routers/gateway/openai_compatible.py @@ -166,7 +166,7 @@ async def generate_completion_stream_generator(app_settings: AppSettings, payloa previous_text = "" async for content in generate_completion_stream(app_settings, "/completion_stream", payload): if content.error_code != ErrorCode.OK: - yield f"data: {json.dumps(content.dict(), ensure_ascii=False)}\n\n" + yield f"data: {json.dumps(content.model_dump(), ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" return decoded_unicode = content.text.replace("\ufffd", "") @@ -269,12 +269,12 @@ async def chat_completion_stream_generator( chunk = ChatCompletionStreamResponse( id=id, choices=[choice_data], model=payload["model"] ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n" previous_text = "" async for content in generate_completion_stream(app_settings, "/chat_stream", payload): if content.error_code != ErrorCode.OK: - yield f"data: {json.dumps(content.dict(), ensure_ascii=False)}\n\n" + yield f"data: {json.dumps(content.model_dump(), ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" return decoded_unicode = content.text.replace("\ufffd", "") @@ -479,7 +479,7 @@ async def api_embeddings(app_settings: AppSettings, request: EmbeddingsRequest): total_tokens=response.usage.total_tokens, completion_tokens=None, ), - ).dict(exclude_none=True) + ).model_dump(exclude_none=True) elif request.encoding_format == "base64": return EmbeddingsResponse( data=[EmbeddingsData( @@ -493,6 +493,6 @@ async def api_embeddings(app_settings: AppSettings, request: EmbeddingsRequest): total_tokens=response.usage.total_tokens, completion_tokens=None, ), - ).dict(exclude_none=True) + ).model_dump(exclude_none=True) else: raise Exception("Invalid encoding_format param.") \ No newline at end of file diff --git a/langport/routers/server/core_node.py b/langport/routers/server/core_node.py index cbf16be..041c783 100644 --- a/langport/routers/server/core_node.py +++ b/langport/routers/server/core_node.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, BackgroundTasks -from langport.protocol.worker_protocol import GetNodeStateRequest, HeartbeatPing, NodeInfoRequest, NodeListRequest, RegisterNodeRequest, RemoveNodeRequest, WorkerAddressRequest +from langport.protocol.worker_protocol import GetNodeStateRequest, GetNodeStateResponse, HeartbeatPing, HeartbeatPong, NodeInfoRequest, NodeInfoResponse, NodeListRequest, NodeListResponse, RegisterNodeRequest, RegisterNodeResponse, RemoveNodeRequest, RemoveNodeResponse, WorkerAddressRequest, WorkerAddressResponse app = FastAPI() @@ -20,40 +20,40 @@ async def shutdown_event(): @app.post("/register_node") async def register_node(request: RegisterNodeRequest): - response = await app.node.api_register_node(request) - return response.dict() + response: RegisterNodeResponse = await app.node.api_register_node(request) + return response.model_dump() @app.post("/remove_node") async def remove_node(request: RemoveNodeRequest): - response = await app.node.api_remove_node(request) - return response.dict() + response: RemoveNodeResponse = await app.node.api_remove_node(request) + return response.model_dump() @app.post("/heartbeat") async def receive_heartbeat(request: HeartbeatPing): - response = await app.node.api_receive_heartbeat(request) - return response.dict() + response: HeartbeatPong = await app.node.api_receive_heartbeat(request) + return response.model_dump() @app.post("/node_list") async def return_node_list(request: NodeListRequest): - response = await app.node.api_return_node_list(request) - return response.dict() + response: NodeListResponse = await app.node.api_return_node_list(request) + return response.model_dump() @app.post("/node_info") async def return_node_info(request: NodeInfoRequest): - response = await app.node.api_return_node_info(request) - return response.dict() + response: NodeInfoResponse = await app.node.api_return_node_info(request) + return response.model_dump() @app.post("/get_node_state") async def api_return_node_state(request: GetNodeStateRequest): - response = await app.node.api_return_node_state(request) - return response.dict() + response: GetNodeStateResponse = await app.node.api_return_node_state(request) + return response.model_dump() @app.post("/get_worker_address") async def api_get_worker_address(request: WorkerAddressRequest): - response = await app.node.api_get_worker_address(request) - return response.dict() + response: WorkerAddressResponse = await app.node.api_get_worker_address(request) + return response.model_dump() diff --git a/langport/routers/server/embedding_node.py b/langport/routers/server/embedding_node.py index b8a5002..4e887c7 100644 --- a/langport/routers/server/embedding_node.py +++ b/langport/routers/server/embedding_node.py @@ -10,4 +10,4 @@ async def api_embeddings(request: EmbeddingsTask): await app.node.acquire_model_semaphore() embedding = await app.node.get_embeddings(request) background_tasks = create_background_tasks(app.node) - return JSONResponse(content=embedding.dict(), background=background_tasks) + return JSONResponse(content=embedding.model_dump(), background=background_tasks) diff --git a/langport/routers/server/generation_node.py b/langport/routers/server/generation_node.py index 7e420b4..3f6cecc 100644 --- a/langport/routers/server/generation_node.py +++ b/langport/routers/server/generation_node.py @@ -1,7 +1,8 @@ +from typing import Optional from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse -from langport.protocol.worker_protocol import EmbeddingsTask, GenerationTask +from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingsTask, GenerationTask from .core_node import app, create_background_tasks @@ -38,13 +39,11 @@ async def api_chat(request: Request): echo=params.get("echo", False), stop_token_ids=params.get("stop_token_ids", None), )) - completion = None + completion: Optional[BaseWorkerResult] = None for chunk in generator: completion = chunk background_tasks = create_background_tasks(app.node) - return JSONResponse(content=completion.dict(), background=background_tasks) - - + return JSONResponse(content=completion.model_dump(), background=background_tasks) @app.post("/completion_stream") async def api_completion_stream(request: Request): @@ -86,4 +85,4 @@ async def api_completion(request: Request): for chunk in generator: completion = chunk background_tasks = create_background_tasks(app.node) - return JSONResponse(content=completion.dict(), background=background_tasks) + return JSONResponse(content=completion.model_dump(), background=background_tasks) diff --git a/langport/service/gateway/cluster_monitor_app.py b/langport/service/gateway/cluster_monitor_app.py index b89b481..ce875f6 100644 --- a/langport/service/gateway/cluster_monitor_app.py +++ b/langport/service/gateway/cluster_monitor_app.py @@ -20,11 +20,11 @@ async def list_workers(app_settings: AppSettings): ) ret = await client.post( app_settings.controller_address + "/get_worker_address", - json=payload.dict(), + json=payload.model_dump(), ) if ret.status_code != 200: return [] - response = WorkerAddressResponse.parse_obj(ret.json()) + response = WorkerAddressResponse.model_validate(ret.json()) address_list = response.address_list data = [json.loads(obj) for obj in response.values] diff --git a/langport/service/gateway/graphite_feeder.py b/langport/service/gateway/graphite_feeder.py index d44431e..86c2b6a 100644 --- a/langport/service/gateway/graphite_feeder.py +++ b/langport/service/gateway/graphite_feeder.py @@ -26,11 +26,11 @@ async def list_workers(app_settings: AppSettings): ) ret = await client.post( app_settings.controller_address + "/get_worker_address", - json=payload.dict(), + json=payload.model_dump(), ) if ret.status_code != 200: return [] - response = WorkerAddressResponse.parse_obj(ret.json()) + response = WorkerAddressResponse.model_validate(ret.json()) address_list = response.address_list data = [json.loads(obj) for obj in response.values] diff --git a/langport/workers/generation_worker.py b/langport/workers/generation_worker.py index 3e431d7..562234d 100644 --- a/langport/workers/generation_worker.py +++ b/langport/workers/generation_worker.py @@ -93,4 +93,4 @@ async def generation_stream(self, task: GenerationTask): async def generation_bytes_stream(self, task: GenerationTask): async for chunk in self.generation_stream(task): - yield json.dumps(chunk.dict()).encode() + b"\0" + yield json.dumps(chunk.model_dump()).encode() + b"\0"