Skip to content

Commit

Permalink
Add traffic middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 3, 2024
1 parent 2a7fbb2 commit 819abc0
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 49 deletions.
3 changes: 2 additions & 1 deletion log_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("uni-api")

logging.getLogger("httpx").setLevel(logging.CRITICAL)
logging.getLogger("httpx").setLevel(logging.CRITICAL)
logging.getLogger("watchfiles.main").setLevel(logging.CRITICAL)
162 changes: 114 additions & 48 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from contextlib import asynccontextmanager

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException, Depends
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from models import RequestModel, ImageGenerationRequest
from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
from request import get_payload
from response import fetch_response, fetch_response_stream
from utils import error_handling_wrapper, post_all_models, load_config

from typing import List, Dict, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -42,36 +42,91 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)

# from time import time
# from collections import defaultdict
# import asyncio

# class StatsMiddleware:
# def __init__(self):
# self.request_counts = defaultdict(int)
# self.request_times = defaultdict(float)
# self.ip_counts = defaultdict(lambda: defaultdict(int))
# self.lock = asyncio.Lock()

# async def __call__(self, request: Request, call_next):
# start_time = time()
# response = await call_next(request)
# process_time = time() - start_time

# endpoint = f"{request.method} {request.url.path}"
# client_ip = request.client.host

# async with self.lock:
# self.request_counts[endpoint] += 1
# self.request_times[endpoint] += process_time
# self.ip_counts[endpoint][client_ip] += 1

# return response
# # 创建 StatsMiddleware 实例
# stats_middleware = StatsMiddleware()

# # 添加 StatsMiddleware
# app.add_middleware(StatsMiddleware)
import asyncio
from time import time
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from datetime import datetime
from datetime import timedelta
import json
import aiofiles

class StatsMiddleware(BaseHTTPMiddleware):
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
super().__init__(app)
self.request_counts = defaultdict(int)
self.request_times = defaultdict(float)
self.ip_counts = defaultdict(lambda: defaultdict(int))
self.request_arrivals = defaultdict(list)
self.lock = asyncio.Lock()
self.exclude_paths = set(exclude_paths or [])
self.save_interval = save_interval
self.filename = filename
self.last_save_time = time()

# 启动定期保存和清理任务
asyncio.create_task(self.periodic_save_and_cleanup())

async def dispatch(self, request: Request, call_next):
arrival_time = datetime.now()
start_time = time()
response = await call_next(request)
process_time = time() - start_time

endpoint = f"{request.method} {request.url.path}"
client_ip = request.client.host

if request.url.path not in self.exclude_paths:
async with self.lock:
self.request_counts[endpoint] += 1
self.request_times[endpoint] += process_time
self.ip_counts[endpoint][client_ip] += 1
self.request_arrivals[endpoint].append(arrival_time)

return response

async def periodic_save_and_cleanup(self):
while True:
await asyncio.sleep(self.save_interval)
await self.save_stats()
await self.cleanup_old_data()

async def save_stats(self):
current_time = time()
if current_time - self.last_save_time < self.save_interval:
return

async with self.lock:
stats = {
"request_counts": dict(self.request_counts),
"request_times": dict(self.request_times),
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
}

filename = self.filename
async with aiofiles.open(filename, mode='w') as f:
await f.write(json.dumps(stats, indent=2))

self.last_save_time = current_time
# print(f"Stats saved to {filename}")

async def cleanup_old_data(self):
# cutoff_time = datetime.now() - timedelta(seconds=30)
cutoff_time = datetime.now() - timedelta(hours=24)
async with self.lock:
for endpoint in list(self.request_arrivals.keys()):
self.request_arrivals[endpoint] = [
t for t in self.request_arrivals[endpoint] if t > cutoff_time
]
if not self.request_arrivals[endpoint]:
del self.request_arrivals[endpoint]
self.request_counts.pop(endpoint, None)
self.request_times.pop(endpoint, None)
self.ip_counts.pop(endpoint, None)

async def cleanup(self):
await self.save_stats()

# 配置 CORS 中间件
app.add_middleware(
Expand All @@ -82,6 +137,8 @@ async def lifespan(app: FastAPI):
allow_headers=["*"], # 允许所有头部字段
)

app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])

async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
url = provider['base_url']
parsed_url = urlparse(url)
Expand Down Expand Up @@ -233,6 +290,17 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
return token

def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
api_list = app.state.api_list
token = credentials.credentials
if token not in api_list:
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
for api_key in app.state.api_keys_db:
if api_key['api'] == token:
if api_key.get('role') != "admin":
raise HTTPException(status_code=403, detail="Permission denied")
return token

@app.post("/v1/chat/completions")
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
return await model_handler.request_model(request, token)
Expand All @@ -258,24 +326,22 @@ async def images_generations(

@app.get("/generate-api-key")
def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(32)
api_key = "sk-" + secrets.token_urlsafe(36)
return JSONResponse(content={"api_key": api_key})

# @app.get("/stats")
# async def get_stats(token: str = Depends(verify_api_key)):
# async with stats_middleware.lock:
# return {
# "request_counts": dict(stats_middleware.request_counts),
# "average_request_times": {
# endpoint: total_time / count
# for endpoint, total_time in stats_middleware.request_times.items()
# for count in [stats_middleware.request_counts[endpoint]]
# },
# "ip_counts": {
# endpoint: dict(ips)
# for endpoint, ips in stats_middleware.ip_counts.items()
# }
# }
@app.get("/stats")
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
middleware = app.middleware_stack.app
if isinstance(middleware, StatsMiddleware):
async with middleware.lock:
stats = {
"request_counts": dict(middleware.request_counts),
"request_times": dict(middleware.request_times),
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
}
return JSONResponse(content=stats)
return {"error": "StatsMiddleware not found"}

# async def on_fetch(request, env):
# import asgi
Expand Down

0 comments on commit 819abc0

Please sign in to comment.