diff --git a/dbgpt/cli/cli_scripts.py b/dbgpt/cli/cli_scripts.py index bcf065639..5b170f9c2 100644 --- a/dbgpt/cli/cli_scripts.py +++ b/dbgpt/cli/cli_scripts.py @@ -82,6 +82,12 @@ def run(): pass +@click.group() +def net(): + """Net tools.""" + pass + + stop_all_func_list = [] @@ -100,6 +106,7 @@ def stop_all(): cli.add_command(app) cli.add_command(repo) cli.add_command(run) +cli.add_command(net) add_command_alias(stop_all, name="all", parent_group=stop) try: @@ -200,6 +207,13 @@ def stop_all(): except ImportError as e: logging.warning(f"Integrating dbgpt client command line tool failed: {e}") +try: + from dbgpt.util.network._cli import start_forward + + add_command_alias(start_forward, name="forward", parent_group=net) +except ImportError as e: + logging.warning(f"Integrating dbgpt net command line tool failed: {e}") + def main(): return cli() diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 3e5c0c8a5..c07c390e4 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -108,6 +108,7 @@ def __init__(self, label: str, description: str): _OPERATOR_CATEGORY_DETAIL = { "trigger": _CategoryDetail("Trigger", "Trigger your AWEL flow"), + "sender": _CategoryDetail("Sender", "Send the data to the target"), "llm": _CategoryDetail("LLM", "Invoke LLM model"), "conversion": _CategoryDetail("Conversion", "Handle the conversion"), "output_parser": _CategoryDetail("Output Parser", "Parse the output of LLM model"), @@ -121,6 +122,7 @@ class OperatorCategory(str, Enum): """The category of the operator.""" TRIGGER = "trigger" + SENDER = "sender" LLM = "llm" CONVERSION = "conversion" OUTPUT_PARSER = "output_parser" diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index d77912335..fbbd9bca6 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -20,6 +20,7 @@ from .exceptions import ( FlowClassMetadataException, FlowDAGMetadataException, + FlowException, FlowMetadataException, ) @@ -720,5 +721,5 @@ def fill_flow_panel(flow_panel: FlowPanel): param.default = new_param.default param.placeholder = new_param.placeholder - except ValueError as e: + except (FlowException, ValueError) as e: logger.warning(f"Unable to fill the flow panel: {e}") diff --git a/dbgpt/core/awel/trigger/ext_http_trigger.py b/dbgpt/core/awel/trigger/ext_http_trigger.py index 2af9a1cbf..06ae025f7 100644 --- a/dbgpt/core/awel/trigger/ext_http_trigger.py +++ b/dbgpt/core/awel/trigger/ext_http_trigger.py @@ -3,13 +3,14 @@ Supports more trigger types, such as RequestHttpTrigger. """ from enum import Enum -from typing import List, Optional, Type, Union +from typing import Dict, List, Optional, Type, Union from starlette.requests import Request from dbgpt.util.i18n_utils import _ -from ..flow import IOField, OperatorCategory, OperatorType, ViewMetadata +from ..flow import IOField, OperatorCategory, OperatorType, Parameter, ViewMetadata +from ..operators.common_operator import MapOperator from .http_trigger import ( _PARAMETER_ENDPOINT, _PARAMETER_MEDIA_TYPE, @@ -82,3 +83,122 @@ def __init__( register_to_app=True, **kwargs, ) + + +class DictHTTPSender(MapOperator[Dict, Dict]): + """HTTP Sender operator for AWEL.""" + + metadata = ViewMetadata( + label=_("HTTP Sender"), + name="awel_dict_http_sender", + category=OperatorCategory.SENDER, + description=_("Send a HTTP request to a specified endpoint"), + inputs=[ + IOField.build_from( + _("Request Body"), + "request_body", + dict, + description=_("The request body to send"), + ) + ], + outputs=[ + IOField.build_from( + _("Response Body"), + "response_body", + dict, + description=_("The response body of the HTTP request"), + ) + ], + parameters=[ + Parameter.build_from( + _("HTTP Address"), + _("address"), + type=str, + description=_("The address to send the HTTP request to"), + ), + _PARAMETER_METHODS_ALL.new(), + _PARAMETER_STATUS_CODE.new(), + Parameter.build_from( + _("Timeout"), + "timeout", + type=int, + optional=True, + default=60, + description=_("The timeout of the HTTP request in seconds"), + ), + Parameter.build_from( + _("Token"), + "token", + type=str, + optional=True, + default=None, + description=_("The token to use for the HTTP request"), + ), + Parameter.build_from( + _("Cookies"), + "cookies", + type=str, + optional=True, + default=None, + description=_("The cookies to use for the HTTP request"), + ), + ], + ) + + def __init__( + self, + address: str, + methods: Optional[str] = "GET", + status_code: Optional[int] = 200, + timeout: Optional[int] = 60, + token: Optional[str] = None, + cookies: Optional[Dict[str, str]] = None, + **kwargs, + ): + """Initialize a HTTPSender.""" + try: + import aiohttp # noqa: F401 + except ImportError: + raise ImportError( + "aiohttp is required for HTTPSender, please install it with " + "`pip install aiohttp`" + ) + self._address = address + self._methods = methods + self._status_code = status_code + self._timeout = timeout + self._token = token + self._cookies = cookies + super().__init__(**kwargs) + + async def map(self, request_body: Dict) -> Dict: + """Send the request body to the specified address.""" + import aiohttp + + if self._methods in ["POST", "PUT"]: + req_kwargs = {"json": request_body} + else: + req_kwargs = {"params": request_body} + method = self._methods or "GET" + + headers = {} + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + async with aiohttp.ClientSession( + headers=headers, + cookies=self._cookies, + timeout=aiohttp.ClientTimeout(total=self._timeout), + ) as session: + async with session.request( + method, + self._address, + raise_for_status=False, + **req_kwargs, + ) as response: + status_code = response.status + if status_code != self._status_code: + raise ValueError( + f"HTTP request failed with status code {status_code}" + ) + response_body = await response.json() + return response_body diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index b71258bc1..f3ee7a31e 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -1036,7 +1036,7 @@ async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]: keys = self._key.split(".") for k in keys: dict_value = dict_value[k] - if isinstance(dict_value, dict): + if not isinstance(dict_value, dict): raise ValueError( f"Prefix key {self._key} is not a valid key of the request body" ) diff --git a/dbgpt/util/network/__init__.py b/dbgpt/util/network/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/util/network/_cli.py b/dbgpt/util/network/_cli.py new file mode 100644 index 000000000..d880bb070 --- /dev/null +++ b/dbgpt/util/network/_cli.py @@ -0,0 +1,289 @@ +import os +import socket +import ssl as py_ssl +import threading + +import click + +from ..console import CliLogger + +logger = CliLogger() + + +def forward_data(source, destination): + """Forward data from source to destination.""" + try: + while True: + data = source.recv(4096) + if b"" == data: + destination.sendall(data) + break + if not data: + break # no more data or connection closed + destination.sendall(data) + except Exception as e: + logger.error(f"Error forwarding data: {e}") + + +def handle_client( + client_socket, + remote_host: str, + remote_port: int, + is_ssl: bool = False, + http_proxy=None, +): + """Handle client connection. + + Create a connection to the remote host and port, and forward data between the + client and the remote host. + + Close the client socket and remote socket when all forwarding threads are done. + """ + # remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if http_proxy: + proxy_host, proxy_port = http_proxy + remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + remote_socket.connect((proxy_host, proxy_port)) + client_ip = client_socket.getpeername()[0] + scheme = "https" if is_ssl else "http" + connect_request = ( + f"CONNECT {remote_host}:{remote_port} HTTP/1.1\r\n" + f"Host: {remote_host}\r\n" + f"Connection: keep-alive\r\n" + f"X-Real-IP: {client_ip}\r\n" + f"X-Forwarded-For: {client_ip}\r\n" + f"X-Forwarded-Proto: {scheme}\r\n\r\n" + ) + logger.info(f"Sending connect request: {connect_request}") + remote_socket.sendall(connect_request.encode()) + + response = b"" + while True: + part = remote_socket.recv(4096) + response += part + if b"\r\n\r\n" in part: + break + + if b"200 Connection established" not in response: + logger.error("Failed to establish connection through proxy") + return + + else: + remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + remote_socket.connect((remote_host, remote_port)) + + if is_ssl: + # context = py_ssl.create_default_context(py_ssl.Purpose.CLIENT_AUTH) + context = py_ssl.create_default_context(py_ssl.Purpose.SERVER_AUTH) + # ssl_target_socket = py_ssl.wrap_socket(remote_socket) + ssl_target_socket = context.wrap_socket( + remote_socket, server_hostname=remote_host + ) + else: + ssl_target_socket = remote_socket + try: + # ssl_target_socket.connect((remote_host, remote_port)) + + # Forward data from client to server + client_to_server = threading.Thread( + target=forward_data, args=(client_socket, ssl_target_socket) + ) + client_to_server.start() + + # Forward data from server to client + server_to_client = threading.Thread( + target=forward_data, args=(ssl_target_socket, client_socket) + ) + server_to_client.start() + + client_to_server.join() + server_to_client.join() + except Exception as e: + logger.error(f"Error handling client connection: {e}") + finally: + # close the client and server sockets + client_socket.close() + ssl_target_socket.close() + + +@click.command(name="forward") +@click.option("--local-port", required=True, type=int, help="Local port to listen on.") +@click.option( + "--remote-host", required=True, type=str, help="Remote host to forward to." +) +@click.option( + "--remote-port", required=True, type=int, help="Remote port to forward to." +) +@click.option( + "--ssl", + is_flag=True, + help="Whether to use SSL for the connection to the remote host.", +) +@click.option( + "--tcp", + is_flag=True, + help="Whether to forward TCP traffic. " + "Default is HTTP. TCP has higher performance but not support proxies now.", +) +@click.option("--timeout", type=int, default=120, help="Timeout for the connection.") +@click.option( + "--proxies", + type=str, + help="HTTP proxy to use for forwarding requests. e.g. http://127.0.0.1:7890, " + "if not specified, try to read from environment variable http_proxy and " + "https_proxy.", +) +def start_forward( + local_port, + remote_host, + remote_port, + ssl: bool, + tcp: bool, + timeout: int, + proxies: str | None = None, +): + """Start a TCP/HTTP proxy server that forwards traffic from a local port to a remote + host and port, just for debugging purposes, please don't use it in production + environment. + """ + + """ + Example: + 1. Forward HTTP traffic: + + ``` + dbgpt net forward --local-port 5010 \ + --remote-host api.openai.com \ + --remote-port 443 \ + --ssl \ + --proxies http://127.0.0.1:7890 \ + --timeout 30 + ``` + Then you can set your environment variable `OPENAI_API_BASE` to + `http://127.0.0.1:5010/v1` + """ + if not tcp: + _start_http_forward(local_port, remote_host, remote_port, ssl, timeout, proxies) + else: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: + server.bind(("0.0.0.0", local_port)) + server.listen(5) + logger.info( + f"[*] Listening on 0.0.0.0:{local_port}, forwarding to " + f"{remote_host}:{remote_port}" + ) + # http_proxy = ("127.0.0.1", 7890) + proxies = ( + proxies or os.environ.get("http_proxy") or os.environ.get("https_proxy") + ) + if proxies: + # proxies = "http://127.0.0.1:7890" + if proxies.startswith("http://") or proxies.startswith("https://"): + proxies = proxies.split("//")[1] + http_proxy = proxies.split(":")[0], int(proxies.split(":")[1]) + + while True: + client_socket, addr = server.accept() + logger.info(f"[*] Accepted connection from: {addr[0]}:{addr[1]}") + client_thread = threading.Thread( + target=handle_client, + args=(client_socket, remote_host, remote_port, ssl, http_proxy), + ) + client_thread.start() + + +def _start_http_forward( + local_port, remote_host, remote_port, ssl: bool, timeout, proxies: str | None = None +): + import httpx + import uvicorn + from fastapi import BackgroundTasks, FastAPI, Request, Response + from fastapi.responses import StreamingResponse + + app = FastAPI() + + @app.middleware("http") + async def forward_http_request(request: Request, call_next): + """Forward HTTP request to remote host.""" + nonlocal proxies + req_body = await request.body() + scheme = request.scope.get("scheme") + path = request.scope.get("path") + headers = dict(request.headers) + # Remove needless headers + stream_response = False + if request.method in ["POST", "PUT"]: + try: + import json + + stream_config = json.loads(req_body.decode("utf-8")) + stream_response = stream_config.get("stream", False) + except Exception: + pass + headers.pop("host", None) + if not proxies: + proxies = os.environ.get("http_proxy") or os.environ.get("https_proxy") + if proxies: + client_req = { + "proxies": { + "http://": proxies, + "https://": proxies, + } + } + else: + client_req = {} + if timeout: + client_req["timeout"] = timeout + + client = httpx.AsyncClient(**client_req) + # async with httpx.AsyncClient(**client_req) as client: + proxy_url = f"{remote_host}:{remote_port}" + if ssl: + scheme = "https" + new_url = ( + proxy_url if "://" in proxy_url else (scheme + "://" + proxy_url + path) + ) + req = client.build_request( + method=request.method, + url=new_url, + cookies=request.cookies, + content=req_body, + headers=headers, + params=request.query_params, + ) + has_connection = False + try: + logger.info(f"Forwarding request to {new_url}") + res = await client.send(req, stream=stream_response) + has_connection = True + if stream_response: + res_headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + background_tasks = BackgroundTasks() + background_tasks.add_task(client.aclose) + return StreamingResponse( + res.aiter_raw(), + headers=res_headers, + media_type=res.headers.get("content-type"), + ) + else: + return Response( + content=res.content, + status_code=res.status_code, + headers=dict(res.headers), + ) + except httpx.ConnectTimeout: + return Response( + content=f"Connection to remote server timeout", status_code=500 + ) + except Exception as e: + return Response(content=str(e), status_code=500) + finally: + if has_connection and not stream_response: + await client.aclose() + + uvicorn.run(app, host="0.0.0.0", port=local_port)