import asyncio
import inspect
import json
import logging
import time
import uuid
from dataclasses import dataclass
from functools import partial
from typing import Dict, Set, Optional, Callable, Any

from quart import Quart, websocket

logger = logging.getLogger(__name__)


@dataclass
class WebSocketConfig:
    """Configuration for the endpoint installed by ``setup_websocket``.

    Callbacks may be plain functions or coroutine functions; their return
    value is awaited only when it is awaitable.

    Raises:
        ValueError: if ``heartbeat_timeout`` or ``heartbeat_interval`` is not
            positive (a non-positive interval would make the receive loop spin).
    """

    # Seconds a client may go without sending {"type": "ping"} before it is closed.
    heartbeat_timeout: float = 45.0
    # Maximum seconds between two checks of the heartbeat timeout.
    heartbeat_interval: float = 15.0
    # Called with client_id after the client has been registered.
    on_connect: Optional[Callable[[str], Any]] = None
    # Called with client_id once per connection when it ends (any reason).
    on_disconnect: Optional[Callable[[str], Any]] = None
    # Called with (client_id, message) for every JSON-object message that is not a ping.
    on_message: Optional[Callable[[str, dict], Any]] = None

    def __post_init__(self) -> None:
        if self.heartbeat_timeout <= 0:
            raise ValueError("heartbeat_timeout must be positive")
        if self.heartbeat_interval <= 0:
            raise ValueError("heartbeat_interval must be positive")


async def _maybe_await(callback: Optional[Callable[..., Any]], *args: Any) -> Any:
    """Call *callback* with *args*, awaiting the result if it is awaitable.

    Returns None without doing anything when *callback* is None.
    """
    if callback is None:
        return None
    result = callback(*args)
    if inspect.isawaitable(result):
        result = await result
    return result


def setup_websocket(
    app: Quart,
    url_prefix: str = '/api/ping/ws',
    config: Optional[WebSocketConfig] = None
) -> None:
    """Register a heartbeat-aware WebSocket endpoint on *app*.

    Each client receives a "connected" message carrying its ``client_id``.
    A ``{"type": "ping"}`` message refreshes the client's heartbeat and is
    answered with a "pong"; any other JSON-object message is forwarded to
    ``config.on_message``. A client that sends no ping for
    ``config.heartbeat_timeout`` seconds is closed (checked at least every
    ``config.heartbeat_interval`` seconds).

    Args:
        app: Quart application instance to register the route on.
        url_prefix: Route path of the WebSocket endpoint.
        config: Endpoint configuration; defaults to ``WebSocketConfig()``.
    """
    if config is None:
        config = WebSocketConfig()

    # Active connections and the time of each client's most recent ping.
    connected_clients: Set[str] = set()
    last_heartbeat: Dict[str, float] = {}

    def heartbeat_expired(client_id: str) -> bool:
        """Return True when *client_id* has not pinged within the timeout."""
        elapsed = time.time() - last_heartbeat.get(client_id, 0)
        return elapsed > config.heartbeat_timeout

    async def handle_message(client_id: str, data: Any) -> None:
        """Parse one incoming frame and dispatch it (ping -> pong, else on_message)."""
        try:
            message = json.loads(data)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (bytes frames
            # that are not valid UTF-8); both subclass ValueError.
            await websocket.send(json.dumps({
                "type": "error",
                "message": "Invalid JSON format"
            }))
            return

        if not isinstance(message, dict):
            # Valid JSON but not an object: reply with an error instead of
            # crashing on .get() and tearing down the connection.
            await websocket.send(json.dumps({
                "type": "error",
                "message": "Message must be a JSON object"
            }))
            return

        if message.get("type") == "ping":
            last_heartbeat[client_id] = time.time()
            await websocket.send(json.dumps({
                "type": "pong",
                "timestamp": time.time()
            }))
        elif config.on_message:
            await _maybe_await(config.on_message, client_id, message)

    @app.websocket(url_prefix)
    async def ws():
        """WebSocket route handler: register, greet, serve until disconnect or timeout."""
        # `websocket` is a context-local proxy shared by all connections, so
        # id(websocket) is identical for every client; use a random id instead.
        client_id = uuid.uuid4().hex
        connected_clients.add(client_id)
        last_heartbeat[client_id] = time.time()
        try:
            await _maybe_await(config.on_connect, client_id)
            logger.info("Client %s connected", client_id)

            # Tell the client it is connected and which id it was assigned.
            await websocket.send(json.dumps({
                "type": "connected",
                "message": "Successfully connected to server",
                "client_id": client_id
            }))

            while True:
                if heartbeat_expired(client_id):
                    logger.info("Client %s timed out", client_id)
                    # Actually close the socket (1000 = normal closure) so the
                    # client does not linger untracked after timing out.
                    await websocket.close(1000)
                    return
                try:
                    # Bound the wait so the timeout is re-checked at least
                    # every heartbeat_interval seconds even when idle.
                    data = await asyncio.wait_for(
                        websocket.receive(), timeout=config.heartbeat_interval
                    )
                except asyncio.TimeoutError:
                    continue
                await handle_message(client_id, data)
        except asyncio.CancelledError:
            # Quart cancels the handler when the client disconnects.
            logger.info("Client %s disconnected", client_id)
            raise
        except Exception:
            logger.exception("Error handling client %s", client_id)
            try:
                await websocket.close(1011)
            except Exception:
                # The transport may already be gone; nothing more to do.
                logger.debug("Failed to close websocket for client %s", client_id)
        finally:
            # Single cleanup point: runs exactly once per connection.
            connected_clients.discard(client_id)
            last_heartbeat.pop(client_id, None)
            try:
                await _maybe_await(config.on_disconnect, client_id)
            except Exception:
                logger.exception("on_disconnect callback failed for client %s", client_id)