131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
from quart import Quart, websocket
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Dict, Set, Optional, Callable, Any
|
|
from dataclasses import dataclass
|
|
from functools import partial
|
|
|
|
|
|
@dataclass
class WebSocketConfig:
    """Settings and optional hooks for the WebSocket service.

    Every field has a default, so ``WebSocketConfig()`` yields a usable
    configuration with no hooks installed.
    """

    # Heartbeat timeout, in seconds.
    heartbeat_timeout: float = 45.0

    # Interval between heartbeat checks, in seconds.
    heartbeat_interval: float = 15.0

    # Connect hook; receives the client id.
    on_connect: Optional[Callable[[str], Any]] = None

    # Disconnect hook; receives the client id.
    on_disconnect: Optional[Callable[[str], Any]] = None

    # Message hook; receives the client id and the decoded message dict.
    on_message: Optional[Callable[[str, dict], Any]] = None
|
|
|
|
|
|
def setup_websocket(
    app: Quart,
    url_prefix: str = '/api/ping/ws',
    config: Optional[WebSocketConfig] = None
) -> None:
    """Register a heartbeat-aware WebSocket endpoint on a Quart application.

    The endpoint greets each client with a ``connected`` message, answers
    ``{"type": "ping"}`` messages with a ``pong``, and forwards every other
    JSON object to ``config.on_message``. A background task, started with the
    application, periodically looks for clients whose last ping is older than
    ``config.heartbeat_timeout`` and terminates their handler.

    ``config.on_disconnect`` is awaited at most once per connection, whether
    the client left, the handler failed, or the heartbeat timed out.

    Args:
        app: Quart application instance to register the route and the
            startup/shutdown hooks on.
        url_prefix: URL rule of the WebSocket route.
        config: WebSocket configuration; a default ``WebSocketConfig`` is
            used when omitted. Its hooks are awaited, so they must return
            awaitables.
    """
    # Local import keeps this block self-contained; uuid is only used here.
    import uuid

    if config is None:
        config = WebSocketConfig()

    # Per-connection bookkeeping, keyed by client id.
    connected_clients: Set[str] = set()
    # Monotonic timestamp of each client's last heartbeat (immune to
    # wall-clock adjustments, unlike time.time()).
    last_heartbeat: Dict[str, float] = {}
    # Handler task of each client, so the heartbeat sweep can terminate it.
    client_tasks: Dict[str, asyncio.Task] = {}

    async def _drop_client(client_id: str) -> None:
        """Forget a client and fire ``on_disconnect`` at most once for it."""
        was_tracked = client_id in connected_clients
        # Remove the bookkeeping first so a failing callback cannot leak it.
        connected_clients.discard(client_id)
        last_heartbeat.pop(client_id, None)
        client_tasks.pop(client_id, None)
        if was_tracked and config.on_disconnect:
            await config.on_disconnect(client_id)

    async def heartbeat_check() -> None:
        """Periodically terminate clients whose heartbeat has expired."""
        while True:
            now = time.monotonic()
            expired = []
            for client_id in connected_clients:
                last_seen = last_heartbeat.get(client_id)
                if last_seen is None or now - last_seen > config.heartbeat_timeout:
                    expired.append(client_id)

            for client_id in expired:
                print(f"Client {client_id} timed out")
                task = client_tasks.get(client_id)
                if task is not None and not task.done():
                    # Cancelling the handler makes its ``finally`` clause run
                    # the cleanup, so the callback fires exactly once.
                    # NOTE(review): assumes Quart tears down the connection
                    # once the handler task ends - confirm.
                    task.cancel()
                    continue
                try:
                    await _drop_client(client_id)
                except Exception as e:
                    # A failing callback must not kill the heartbeat task.
                    print(f"Error handling client {client_id}: {str(e)}")

            await asyncio.sleep(config.heartbeat_interval)

    @app.websocket(url_prefix)
    async def ws():
        """Handle one WebSocket connection for its whole lifetime."""
        # ``websocket`` is a context-local proxy shared by all connections,
        # so ``id(websocket)`` is identical for every client; generate a
        # unique id per connection instead.
        client_id = uuid.uuid4().hex
        connected_clients.add(client_id)
        last_heartbeat[client_id] = time.monotonic()
        current = asyncio.current_task()
        if current is not None:
            client_tasks[client_id] = current

        try:
            if config.on_connect:
                await config.on_connect(client_id)

            print(f"Client {client_id} connected")

            # Tell the client the connection succeeded.
            await websocket.send(json.dumps({
                "type": "connected",
                "message": "Successfully connected to server",
                "client_id": client_id
            }))

            while True:
                data = await websocket.receive()
                try:
                    message = json.loads(data)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": "Invalid JSON format"
                    }))
                    continue

                # Valid JSON that is not an object (list, number, string...)
                # has no "type" field; reject it instead of crashing.
                if not isinstance(message, dict):
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": "Message must be a JSON object"
                    }))
                    continue

                if message.get("type") == "ping":
                    # Heartbeat: refresh liveness and answer with a pong.
                    last_heartbeat[client_id] = time.monotonic()
                    await websocket.send(json.dumps({
                        "type": "pong",
                        "timestamp": time.time()
                    }))
                elif config.on_message:
                    # Any other message goes to the user-supplied hook.
                    await config.on_message(client_id, message)

        except asyncio.CancelledError:
            # Raised when the handler task is cancelled, e.g. by the
            # heartbeat sweep above.
            print(f"Client {client_id} disconnected")
            raise

        except Exception as e:
            print(f"Error handling client {client_id}: {str(e)}")
            await websocket.close(1011)

        finally:
            await _drop_client(client_id)

    @app.before_serving
    async def startup():
        """Start the heartbeat sweep when the application starts serving."""
        app.heartbeat_task = asyncio.create_task(heartbeat_check())

    @app.after_serving
    async def shutdown():
        """Stop the heartbeat sweep when the application shuts down."""
        if hasattr(app, 'heartbeat_task'):
            app.heartbeat_task.cancel()
            try:
                await app.heartbeat_task
            except asyncio.CancelledError:
                # Expected: the task was cancelled on the line above.
                pass
|