Skip to content

Commit

Permalink
remove gateway.
Browse files Browse the repository at this point in the history
  • Loading branch information
lkk12014402 committed Dec 13, 2024
1 parent 2442af3 commit 09ac77f
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 55 deletions.
17 changes: 11 additions & 6 deletions CodeGen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import os

from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -13,6 +13,7 @@
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from comps.cores.mega.utils import handle_message
from fastapi import Request
from fastapi.responses import StreamingResponse

Expand All @@ -21,11 +22,12 @@
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))


class CodeGenService(Gateway):
class CodeGenService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.CODE_GEN)

def add_remote_service(self):
llm = MicroService(
Expand All @@ -42,7 +44,7 @@ async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
prompt = handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
Expand Down Expand Up @@ -78,14 +80,17 @@ async def handle_request(self, request: Request):
return ChatCompletionResponse(model="codegen", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CODE_GEN),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions CodeTrans/code_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import os

from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -20,11 +20,12 @@
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))


class CodeTransService(Gateway):
class CodeTransService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.CODE_TRANS)

def add_remote_service(self):
llm = MicroService(
Expand Down Expand Up @@ -77,14 +78,17 @@ async def handle_request(self, request: Request):
return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CODE_TRANS),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions DocIndexRetriever/retrieval_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from typing import Union

from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from fastapi import Request
Expand All @@ -21,11 +21,12 @@
RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)


class RetrievalToolService(Gateway):
class RetrievalToolService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.RETRIEVALTOOL)

def add_remote_service(self):
embedding = MicroService(
Expand Down Expand Up @@ -116,14 +117,17 @@ def parser_input(data, TypeClass, key):
return response

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL),
endpoint=self.endpoint,
input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
output_datatype=Union[RerankedDoc, LLMParamsDoc],
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()

def add_remote_service_without_rerank(self):
embedding = MicroService(
Expand Down
14 changes: 9 additions & 5 deletions EdgeCraftRAG/chatqna.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))

from comps import Gateway, MegaServiceEndpoint
from comps import MegaServiceEndpoint, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -22,11 +22,12 @@
from fastapi.responses import StreamingResponse


class EdgeCraftRagService(Gateway):
class EdgeCraftRagService:
def __init__(self, host="0.0.0.0", port=16010):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.CHAT_QNA)

def add_remote_service(self):
edgecraftrag = MicroService(
Expand Down Expand Up @@ -72,14 +73,17 @@ async def handle_request(self, request: Request):
return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.CHAT_QNA),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
17 changes: 11 additions & 6 deletions GraphRAG/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import re

from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -16,6 +16,7 @@
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams, RetrieverParms, TextDoc
from comps.cores.mega.utils import handle_message
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
Expand Down Expand Up @@ -127,14 +128,15 @@ def align_generator(self, gen, **kwargs):
yield "data: [DONE]\n\n"


class GraphRAGService(Gateway):
class GraphRAGService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
ServiceOrchestrator.align_inputs = align_inputs
ServiceOrchestrator.align_outputs = align_outputs
ServiceOrchestrator.align_generator = align_generator
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.GRAPH_RAG)

def add_remote_service(self):
retriever = MicroService(
Expand Down Expand Up @@ -180,7 +182,7 @@ def parser_input(data, TypeClass, key):
raise ValueError(f"Unknown request type: {data}")
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")
prompt = self._handle_message(chat_request.messages)
prompt = handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
Expand Down Expand Up @@ -223,14 +225,17 @@ def parser_input(data, TypeClass, key):
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.GRAPH_RAG),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
15 changes: 9 additions & 6 deletions MultimodalQnA/multimodalqna.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from io import BytesIO

import requests
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -29,7 +29,7 @@
LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9399))


class MultimodalQnAService(Gateway):
class MultimodalQnAService:
asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))

Expand All @@ -38,6 +38,7 @@ def __init__(self, host="0.0.0.0", port=8000):
self.port = port
self.lvm_megaservice = ServiceOrchestrator()
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.MULTIMODAL_QNA)

def add_remote_service(self):
mm_embedding = MicroService(
Expand Down Expand Up @@ -74,7 +75,6 @@ def add_remote_service(self):
# for lvm megaservice
self.lvm_megaservice.add(lvm)

# this overrides _handle_message method of Gateway
def _handle_message(self, messages):
images = []
audios = []
Expand Down Expand Up @@ -303,14 +303,17 @@ async def handle_request(self, request: Request):
return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
17 changes: 11 additions & 6 deletions SearchQnA/searchqna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os

from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType, ServiceRoleType
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -12,6 +12,7 @@
UsageInfo,
)
from comps.cores.proto.docarray import LLMParams
from comps.cores.mega.utils import handle_message
from fastapi import Request
from fastapi.responses import StreamingResponse

Expand All @@ -26,11 +27,12 @@
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))


class SearchQnAService(Gateway):
class SearchQnAService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.SEARCH_QNA)

def add_remote_service(self):
embedding = MicroService(
Expand Down Expand Up @@ -74,7 +76,7 @@ async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
prompt = handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
Expand Down Expand Up @@ -110,14 +112,17 @@ async def handle_request(self, request: Request):
return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage)

def start(self):
super().__init__(
megaservice=self.megaservice,
self.service = MicroService(
self.__class__.__name__,
service_role=ServiceRoleType.MEGASERVICE,
host=self.host,
port=self.port,
endpoint=str(MegaServiceEndpoint.SEARCH_QNA),
endpoint=self.endpoint,
input_datatype=ChatCompletionRequest,
output_datatype=ChatCompletionResponse,
)
self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
self.service.start()


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 09ac77f

Please sign in to comment.