# 2025-03-05 11:40:19 +08:00
#
# 131 lines
# 4.2 KiB
# Python

import asyncio
import inspect
import json
import time
import uuid
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Set

from quart import Quart, websocket
@dataclass
class WebSocketConfig:
    """Configuration for the WebSocket endpoint.

    Callbacks receive the connection's client id (a string); ``on_message``
    additionally receives the decoded JSON message.

    Raises:
        ValueError: if ``heartbeat_timeout`` or ``heartbeat_interval`` is not
            a positive number.
    """
    heartbeat_timeout: float = 45.0  # Seconds without a ping before a client is considered timed out
    heartbeat_interval: float = 15.0  # Seconds between heartbeat sweeps
    on_connect: Optional[Callable[[str], Any]] = None  # Called when a client connects
    on_disconnect: Optional[Callable[[str], Any]] = None  # Called when a client disconnects
    on_message: Optional[Callable[[str, dict], Any]] = None  # Called for every non-ping message

    def __post_init__(self) -> None:
        # A non-positive interval would make the heartbeat loop spin on
        # asyncio.sleep(<=0); a non-positive timeout would expire every client
        # on the first sweep. `not x > 0` also rejects NaN.
        if not self.heartbeat_timeout > 0:
            raise ValueError(
                f"heartbeat_timeout must be positive, got {self.heartbeat_timeout!r}"
            )
        if not self.heartbeat_interval > 0:
            raise ValueError(
                f"heartbeat_interval must be positive, got {self.heartbeat_interval!r}"
            )
async def _maybe_await(result: Any) -> Any:
    """Await *result* if it is awaitable, otherwise return it unchanged.

    Lets WebSocketConfig callbacks be plain functions or coroutine functions.
    """
    if inspect.isawaitable(result):
        return await result
    return result


def setup_websocket(
    app: Quart,
    url_prefix: str = '/api/ping/ws',
    config: Optional[WebSocketConfig] = None
) -> None:
    """Register a heartbeat-monitored WebSocket endpoint on *app*.

    Clients send ``{"type": "ping"}`` messages; each ping is answered with a
    ``pong`` and refreshes that client's heartbeat. A background task started
    in ``before_serving`` cancels the handler of any client whose last ping is
    older than ``config.heartbeat_timeout``. Every other JSON-object message is
    forwarded to ``config.on_message``.

    Callbacks in *config* may be plain functions or coroutine functions.
    ``on_disconnect`` is called exactly once per connection, whether the
    client disconnected, timed out, or the handler failed.

    Args:
        app: Quart application instance.
        url_prefix: Route path of the WebSocket endpoint.
        config: WebSocket configuration; defaults to ``WebSocketConfig()``.
    """
    if config is None:
        config = WebSocketConfig()

    # Ids of connections that are currently monitored by the heartbeat checker.
    connected_clients: Set[str] = set()
    # Monotonic time of each client's last ping (immune to wall-clock jumps).
    last_heartbeat: Dict[str, float] = {}
    # Task running each client's handler, so the heartbeat checker can end it.
    client_tasks: Dict[str, asyncio.Task] = {}

    async def heartbeat_check() -> None:
        """Periodically cancel the handlers of clients whose heartbeat expired."""
        while True:
            now = time.monotonic()
            # Iterate over a snapshot because the set is mutated below.
            for client_id in list(connected_clients):
                if now - last_heartbeat.get(client_id, 0.0) <= config.heartbeat_timeout:
                    continue
                # Stop monitoring right away so a slow cleanup is not
                # timed out (and cancelled) a second time.
                connected_clients.discard(client_id)
                print(f"Client {client_id} timed out")
                # Cancelling the handler runs its `finally` cleanup, which
                # removes the remaining state and calls on_disconnect once.
                task = client_tasks.get(client_id)
                if task is not None:
                    task.cancel()
            await asyncio.sleep(config.heartbeat_interval)

    @app.websocket(url_prefix)
    async def ws():
        """Serve one WebSocket connection until it closes, fails or times out."""
        # id(websocket) would be the id of Quart's shared context-local proxy,
        # identical for every connection, so use a random UUID instead.
        client_id = uuid.uuid4().hex
        last_heartbeat[client_id] = time.monotonic()
        connected_clients.add(client_id)
        client_tasks[client_id] = asyncio.current_task()
        try:
            if config.on_connect:
                await _maybe_await(config.on_connect(client_id))
            print(f"Client {client_id} connected")
            # Tell the client the connection is established.
            await websocket.send(json.dumps({
                "type": "connected",
                "message": "Successfully connected to server",
                "client_id": client_id
            }))
            while True:
                data = await websocket.receive()
                try:
                    message = json.loads(data)
                except ValueError:  # JSONDecodeError, or undecodable bytes
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": "Invalid JSON format"
                    }))
                    continue
                if not isinstance(message, dict):
                    # Valid JSON but not an object: `.get` below would fail.
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": "Message must be a JSON object"
                    }))
                    continue
                if message.get("type") == "ping":
                    # Heartbeat: refresh the timer and answer with a pong.
                    last_heartbeat[client_id] = time.monotonic()
                    await websocket.send(json.dumps({
                        "type": "pong",
                        "timestamp": time.time()
                    }))
                elif config.on_message:
                    await _maybe_await(config.on_message(client_id, message))
        except asyncio.CancelledError:
            # Raised when the client disconnects or the heartbeat timed out.
            print(f"Client {client_id} disconnected")
            raise
        except Exception as e:
            print(f"Error handling client {client_id}: {str(e)}")
            try:
                await websocket.close(1011)
            except Exception as close_error:
                print(f"Failed to close client {client_id}: {close_error}")
        finally:
            # Single cleanup path for every way the handler can end.
            connected_clients.discard(client_id)
            last_heartbeat.pop(client_id, None)
            client_tasks.pop(client_id, None)
            if config.on_disconnect:
                try:
                    await _maybe_await(config.on_disconnect(client_id))
                except Exception as callback_error:
                    # Must not mask a pending CancelledError or handler error.
                    print(f"on_disconnect failed for client {client_id}: {callback_error}")

    @app.before_serving
    async def startup():
        """Start the heartbeat checker when the server starts serving."""
        app.heartbeat_task = asyncio.create_task(heartbeat_check())

    @app.after_serving
    async def shutdown():
        """Cancel the heartbeat checker and wait for it to finish."""
        task = getattr(app, 'heartbeat_task', None)
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass