diff options
Diffstat (limited to 'web/websocket/hub.go')
| -rw-r--r-- | web/websocket/hub.go | 535 |
1 files changed, 256 insertions, 279 deletions
diff --git a/web/websocket/hub.go b/web/websocket/hub.go index 1455d1fa..29ba384e 100644 --- a/web/websocket/hub.go +++ b/web/websocket/hub.go @@ -1,402 +1,379 @@ -// Package websocket provides WebSocket hub for real-time updates and notifications. +// Package websocket provides a WebSocket hub for real-time updates and notifications. package websocket import ( "context" "encoding/json" - "runtime" "sync" "time" "github.com/mhsanaei/3x-ui/v2/logger" ) -// MessageType represents the type of WebSocket message +// MessageType identifies the kind of WebSocket message. type MessageType string const ( - MessageTypeStatus MessageType = "status" // Server status update - MessageTypeTraffic MessageType = "traffic" // Traffic statistics update - MessageTypeInbounds MessageType = "inbounds" // Inbounds list update - MessageTypeNotification MessageType = "notification" // System notification - MessageTypeXrayState MessageType = "xray_state" // Xray state change - MessageTypeOutbounds MessageType = "outbounds" // Outbounds list update - MessageTypeInvalidate MessageType = "invalidate" // Lightweight signal telling frontend to re-fetch data via REST + MessageTypeStatus MessageType = "status" + MessageTypeTraffic MessageType = "traffic" + MessageTypeInbounds MessageType = "inbounds" + MessageTypeOutbounds MessageType = "outbounds" + MessageTypeNotification MessageType = "notification" + MessageTypeXrayState MessageType = "xray_state" + // MessageTypeClientStats carries absolute traffic counters for the clients + // that had activity in the latest collection window. Frontend applies these + // in-place — far smaller than re-broadcasting the full inbound list and + // scales to 10k+ clients without falling back to REST. + MessageTypeClientStats MessageType = "client_stats" + MessageTypeInvalidate MessageType = "invalidate" // Tells frontend to re-fetch via REST (last-resort). + + // maxMessageSize caps the WebSocket payload. Beyond this the hub sends a + // lightweight invalidate signal and the frontend re-fetches via REST. + // 10MB lets typical 2k–8k-client deployments push directly via WS (low + // latency); larger installs fall back to invalidate. + maxMessageSize = 10 * 1024 * 1024 // 10MB + + enqueueTimeout = 100 * time.Millisecond + clientSendQueue = 512 // ~50s of buffering for a momentarily slow browser. + hubBroadcastQueue = 2048 // Headroom for cron-storm + admin-mutation bursts. + hubControlQueue = 64 // Backlog for register/unregister bursts (page reloads, disconnect storms). + + // minBroadcastInterval throttles per-type broadcasts so cron storms or + // rapid mutations cannot drown the hub. Bursts within the interval are + // dropped (not coalesced); the next broadcast outside the window delivers + // the latest state. Only message types in throttledMessageTypes are gated — + // heartbeat and real-time signals (status, traffic, client_stats, + // notification, xray_state, invalidate) bypass this so they are never delayed. + minBroadcastInterval = 250 * time.Millisecond + + // hubRestartAttempts caps panic-recovery restarts. After this many + // consecutive failures we stop trying and log; the panel keeps running + // (frontend falls back to REST polling) and the operator can investigate. + hubRestartAttempts = 3 ) -// Message represents a WebSocket message +// NewClient builds a Client ready for hub registration. +func NewClient(id string) *Client { + return &Client{ + ID: id, + Send: make(chan []byte, clientSendQueue), + } +} + +// Message is the wire format sent to clients. type Message struct { Type MessageType `json:"type"` Payload any `json:"payload"` Time int64 `json:"time"` } -// Client represents a WebSocket client connection +// Client represents a single WebSocket connection. type Client struct { ID string Send chan []byte - Hub *Hub - Topics map[MessageType]bool // Subscribed topics - closeOnce sync.Once // Ensures Send channel is closed exactly once + closeOnce sync.Once } -// Hub maintains the set of active clients and broadcasts messages to them +// Hub fan-outs messages to all connected clients. type Hub struct { - // Registered clients - clients map[*Client]bool - - // Inbound messages from clients - broadcast chan []byte - - // Register requests from clients - register chan *Client - - // Unregister requests from clients + clients map[*Client]struct{} + broadcast chan []byte + register chan *Client unregister chan *Client + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc - // Mutex for thread-safe operations - mu sync.RWMutex - - // Context for graceful shutdown - ctx context.Context - cancel context.CancelFunc - - // Worker pool for parallel broadcasting - workerPoolSize int + throttleMu sync.Mutex + lastBroadcast map[MessageType]time.Time } -// NewHub creates a new WebSocket hub +// NewHub creates a hub. Call Run in a goroutine to start its event loop. func NewHub() *Hub { ctx, cancel := context.WithCancel(context.Background()) - - // Calculate optimal worker pool size (CPU cores * 2, but max 100) - workerPoolSize := runtime.NumCPU() * 2 - if workerPoolSize > 100 { - workerPoolSize = 100 - } - if workerPoolSize < 10 { - workerPoolSize = 10 - } - return &Hub{ - clients: make(map[*Client]bool), - broadcast: make(chan []byte, 2048), // Increased from 256 to 2048 for high load - register: make(chan *Client, 100), // Buffered channel for fast registration - unregister: make(chan *Client, 100), // Buffered channel for fast unregistration - ctx: ctx, - cancel: cancel, - workerPoolSize: workerPoolSize, + clients: make(map[*Client]struct{}), + broadcast: make(chan []byte, hubBroadcastQueue), + register: make(chan *Client, hubControlQueue), + unregister: make(chan *Client, hubControlQueue), + ctx: ctx, + cancel: cancel, + lastBroadcast: make(map[MessageType]time.Time), } } -// Run starts the hub's main loop +// throttledMessageTypes is the explicit allow-list of message types subject to +// the per-type rate limit. Everything else (status, traffic, client_stats, +// notification, xray_state, invalidate) is heartbeat- or signal-class and must +// not be delayed. Keeping the set explicit (vs. an exclusion list) makes the +// intent obvious when new message types are added — by default they bypass. +var throttledMessageTypes = map[MessageType]struct{}{ + MessageTypeInbounds: {}, + MessageTypeOutbounds: {}, +} + +// shouldThrottle returns true if a broadcast of msgType is rate-limited and +// happened within minBroadcastInterval of the previous one. Only message types +// in throttledMessageTypes are gated. +func (h *Hub) shouldThrottle(msgType MessageType) bool { + if _, gated := throttledMessageTypes[msgType]; !gated { + return false + } + h.throttleMu.Lock() + defer h.throttleMu.Unlock() + now := time.Now() + if last, ok := h.lastBroadcast[msgType]; ok && now.Sub(last) < minBroadcastInterval { + return true + } + h.lastBroadcast[msgType] = now + return false +} + +// Run drives the hub. The inner loop is wrapped in a panic-recovery harness +// that retries up to hubRestartAttempts times with backoff so a transient +// panic doesn't permanently kill real-time updates for commercial deployments. +// After the cap, the hub stays down and the frontend falls back to REST polling. func (h *Hub) Run() { + for attempt := 0; attempt < hubRestartAttempts; attempt++ { + stopped := h.runOnce() + if stopped { + return + } + if attempt < hubRestartAttempts-1 { + wait := time.Duration(1<<attempt) * time.Second // 1s, 2s, 4s + logger.Errorf("WebSocket hub crashed, restarting in %s (%d/%d)", wait, attempt+1, hubRestartAttempts-1) + select { + case <-time.After(wait): + case <-h.ctx.Done(): + return + } + } + } + logger.Error("WebSocket hub stopped after exhausting restart attempts") +} + +// runOnce drives the event loop once and returns true if the hub stopped +// cleanly (context cancelled). On panic, recover logs and returns false so +// Run can decide whether to retry. +func (h *Hub) runOnce() (stopped bool) { defer func() { if r := recover(); r != nil { - logger.Error("WebSocket hub panic recovered:", r) - // Restart the hub loop - go h.Run() + logger.Errorf("WebSocket hub panic recovered: %v", r) + stopped = false } }() for { select { case <-h.ctx.Done(): - // Graceful shutdown: close all clients - h.mu.Lock() - for client := range h.clients { - client.closeOnce.Do(func() { - close(client.Send) - }) - } - h.clients = make(map[*Client]bool) - h.mu.Unlock() - logger.Info("WebSocket hub stopped gracefully") - return + h.shutdown() + return true - case client := <-h.register: - if client == nil { + case c := <-h.register: + if c == nil { continue } h.mu.Lock() - h.clients[client] = true - count := len(h.clients) + h.clients[c] = struct{}{} + n := len(h.clients) h.mu.Unlock() - logger.Debugf("WebSocket client connected: %s (total: %d)", client.ID, count) + logger.Debugf("WebSocket client connected: %s (total: %d)", c.ID, n) - case client := <-h.unregister: - if client == nil { + case c := <-h.unregister: + if c == nil { continue } - h.mu.Lock() - if _, ok := h.clients[client]; ok { - delete(h.clients, client) - client.closeOnce.Do(func() { - close(client.Send) - }) - } - count := len(h.clients) - h.mu.Unlock() - logger.Debugf("WebSocket client disconnected: %s (total: %d)", client.ID, count) - - case message := <-h.broadcast: - if message == nil { - continue - } - // Optimization: quickly copy client list and release lock - h.mu.RLock() - clientCount := len(h.clients) - if clientCount == 0 { - h.mu.RUnlock() - continue - } - - // Pre-allocate memory for client list - clients := make([]*Client, 0, clientCount) - for client := range h.clients { - clients = append(clients, client) - } - h.mu.RUnlock() + h.removeClient(c) - // Parallel broadcast using worker pool - h.broadcastParallel(clients, message) + case msg := <-h.broadcast: + h.fanout(msg) } } } -// broadcastParallel sends message to all clients in parallel for maximum performance -func (h *Hub) broadcastParallel(clients []*Client, message []byte) { - if len(clients) == 0 { - return - } - - // For small number of clients, use simple parallel sending - if len(clients) < h.workerPoolSize { - var wg sync.WaitGroup - for _, client := range clients { - wg.Add(1) - go func(c *Client) { - defer wg.Done() - defer func() { - if r := recover(); r != nil { - // Channel may be closed, safely ignore - logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", c.ID, r) - } - }() - select { - case c.Send <- message: - default: - // Client's send buffer is full, disconnect - logger.Debugf("WebSocket client %s send buffer full, disconnecting", c.ID) - h.Unregister(c) - } - }(client) - } - wg.Wait() - return +// shutdown closes all client send channels and clears the registry. +func (h *Hub) shutdown() { + h.mu.Lock() + for c := range h.clients { + c.closeOnce.Do(func() { close(c.Send) }) } + h.clients = make(map[*Client]struct{}) + h.mu.Unlock() + logger.Info("WebSocket hub stopped") +} - // For large number of clients, use worker pool for optimal performance - clientChan := make(chan *Client, len(clients)) - for _, client := range clients { - clientChan <- client +// removeClient deletes a client and closes its send channel exactly once. +func (h *Hub) removeClient(c *Client) { + h.mu.Lock() + if _, ok := h.clients[c]; ok { + delete(h.clients, c) + c.closeOnce.Do(func() { close(c.Send) }) } - close(clientChan) - - // Use a local WaitGroup to avoid blocking hub shutdown - var wg sync.WaitGroup - wg.Add(h.workerPoolSize) - for i := 0; i < h.workerPoolSize; i++ { - go func() { - defer wg.Done() - for client := range clientChan { - func() { - defer func() { - if r := recover(); r != nil { - // Channel may be closed, safely ignore - logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", client.ID, r) - } - }() - select { - case client.Send <- message: - default: - // Client's send buffer is full, disconnect - logger.Debugf("WebSocket client %s send buffer full, disconnecting", client.ID) - h.Unregister(client) - } - }() - } - }() - } - - // Wait for all workers to finish - wg.Wait() + n := len(h.clients) + h.mu.Unlock() + logger.Debugf("WebSocket client disconnected: %s (total: %d)", c.ID, n) } -// Broadcast sends a message to all connected clients -func (h *Hub) Broadcast(messageType MessageType, payload any) { - if h == nil { +// fanout delivers msg to every client. Each send is non-blocking — a client +// whose buffer is full is collected for direct removal at the end. We do NOT +// route slow-client unregistrations through the unregister channel: under +// burst load (panel restart, network blip) that channel can fill up while the +// hub itself is the consumer, causing a self-deadlock. +func (h *Hub) fanout(msg []byte) { + if msg == nil { return } - if payload == nil { - logger.Warning("Attempted to broadcast nil payload") + h.mu.RLock() + if len(h.clients) == 0 { + h.mu.RUnlock() return } - - // Skip all work if no clients are connected - if h.GetClientCount() == 0 { - return + targets := make([]*Client, 0, len(h.clients)) + for c := range h.clients { + targets = append(targets, c) } + h.mu.RUnlock() - msg := Message{ - Type: messageType, - Payload: payload, - Time: getCurrentTimestamp(), + var dead []*Client + for _, c := range targets { + if !trySend(c, msg) { + dead = append(dead, c) + } } - data, err := json.Marshal(msg) - if err != nil { - logger.Error("Failed to marshal WebSocket message:", err) + if len(dead) == 0 { return } - - // If message exceeds size limit, send a lightweight invalidate notification - // instead of dropping it entirely — the frontend will re-fetch via REST API - const maxMessageSize = 10 * 1024 * 1024 // 10MB - if len(data) > maxMessageSize { - logger.Debugf("WebSocket message too large (%d bytes) for type %s, sending invalidate signal", len(data), messageType) - h.broadcastInvalidate(messageType) - return + h.mu.Lock() + for _, c := range dead { + if _, ok := h.clients[c]; ok { + delete(h.clients, c) + c.closeOnce.Do(func() { close(c.Send) }) + logger.Debugf("WebSocket client %s send buffer full, disconnected", c.ID) + } } + h.mu.Unlock() +} - // Non-blocking send with timeout to prevent delays +// trySend performs a non-blocking write to the client's Send channel. +// Returns false if the client should be evicted (full buffer or closed channel). +// A defer-recover guards against the rare race where the channel was closed +// concurrently — sending on a closed channel always panics, even with select+default. +func trySend(c *Client, msg []byte) (ok bool) { + defer func() { + if r := recover(); r != nil { + ok = false + } + }() select { - case h.broadcast <- data: - case <-time.After(100 * time.Millisecond): - logger.Warning("WebSocket broadcast channel is full, dropping message") - case <-h.ctx.Done(): - // Hub is shutting down + case c.Send <- msg: + return true + default: + return false } } -// BroadcastToTopic sends a message only to clients subscribed to the specific topic -func (h *Hub) BroadcastToTopic(messageType MessageType, payload any) { - if h == nil { - return - } - if payload == nil { - logger.Warning("Attempted to broadcast nil payload to topic") +// Broadcast serializes payload and queues it for delivery to all clients. +// If the serialized message exceeds maxMessageSize, an invalidate signal is +// queued instead so the frontend re-fetches via REST. Broadcasts of throttled +// message types (see throttledMessageTypes) within minBroadcastInterval of +// the previous one are dropped — the next legitimate mutation will push the +// fresh state. +func (h *Hub) Broadcast(messageType MessageType, payload any) { + if h == nil || payload == nil || h.GetClientCount() == 0 { return } - - // Skip all work if no clients are connected - if h.GetClientCount() == 0 { + if h.shouldThrottle(messageType) { return } - - msg := Message{ + data, err := json.Marshal(Message{ Type: messageType, Payload: payload, - Time: getCurrentTimestamp(), - } - - data, err := json.Marshal(msg) + Time: time.Now().UnixMilli(), + }) if err != nil { - logger.Error("Failed to marshal WebSocket message:", err) + logger.Error("WebSocket marshal failed:", err) return } - - // If message exceeds size limit, send a lightweight invalidate notification - const maxMessageSize = 10 * 1024 * 1024 // 10MB if len(data) > maxMessageSize { - logger.Debugf("WebSocket message too large (%d bytes) for type %s, sending invalidate signal", len(data), messageType) + logger.Debugf("WebSocket payload %d bytes exceeds limit, sending invalidate for %s", len(data), messageType) h.broadcastInvalidate(messageType) return } + h.enqueue(data) +} - h.mu.RLock() - // Filter clients by topics and quickly release lock - subscribedClients := make([]*Client, 0) - for client := range h.clients { - if len(client.Topics) == 0 || client.Topics[messageType] { - subscribedClients = append(subscribedClients, client) - } +// broadcastInvalidate queues a lightweight signal telling clients to re-fetch +// the named data type via REST. +func (h *Hub) broadcastInvalidate(originalType MessageType) { + data, err := json.Marshal(Message{ + Type: MessageTypeInvalidate, + Payload: map[string]string{"type": string(originalType)}, + Time: time.Now().UnixMilli(), + }) + if err != nil { + logger.Error("WebSocket invalidate marshal failed:", err) + return } - h.mu.RUnlock() + h.enqueue(data) +} - // Parallel send to subscribed clients - if len(subscribedClients) > 0 { - h.broadcastParallel(subscribedClients, data) +// enqueue submits raw bytes to the broadcast channel. Dropped on backpressure +// (channel full for >100ms) or shutdown. +func (h *Hub) enqueue(data []byte) { + select { + case h.broadcast <- data: + case <-time.After(enqueueTimeout): + logger.Warning("WebSocket broadcast channel full, dropping message") + case <-h.ctx.Done(): } } -// GetClientCount returns the number of connected clients +// GetClientCount returns the number of connected clients. func (h *Hub) GetClientCount() int { + if h == nil { + return 0 + } h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } -// Register registers a new client with the hub -func (h *Hub) Register(client *Client) { - if h == nil || client == nil { +// Register adds a client to the hub. +func (h *Hub) Register(c *Client) { + if h == nil || c == nil { return } select { - case h.register <- client: + case h.register <- c: case <-h.ctx.Done(): - // Hub is shutting down } } -// Unregister unregisters a client from the hub -func (h *Hub) Unregister(client *Client) { - if h == nil || client == nil { +// Unregister removes a client from the hub. Fast path queues for the hub +// goroutine; if the channel is saturated (disconnect storm) we fall back +// to a direct removal under the write lock so dead clients aren't left in +// the registry waiting for their Send buffer to fill (minutes of wasted +// fanout work at low broadcast rates). +// +// Direct removal is safe from any caller: external goroutines (read/write +// pumps) hold no hub locks, and the hub goroutine itself never holds h.mu +// when it calls Unregister — fanout releases its RLock before per-client +// sends, so we can't self-deadlock here. +func (h *Hub) Unregister(c *Client) { + if h == nil || c == nil { return } select { - case h.unregister <- client: - case <-h.ctx.Done(): - // Hub is shutting down + case h.unregister <- c: + default: + h.removeClient(c) } } -// Stop gracefully stops the hub and closes all connections +// Stop signals the hub to shut down and close all client connections. func (h *Hub) Stop() { - if h == nil { - return - } - if h.cancel != nil { + if h != nil && h.cancel != nil { h.cancel() } } - -// broadcastInvalidate sends a lightweight invalidate message to all clients, -// telling them to re-fetch the specified data type via REST API. -// This is used when the full payload exceeds the WebSocket message size limit. -func (h *Hub) broadcastInvalidate(originalType MessageType) { - msg := Message{ - Type: MessageTypeInvalidate, - Payload: map[string]string{"type": string(originalType)}, - Time: getCurrentTimestamp(), - } - - data, err := json.Marshal(msg) - if err != nil { - logger.Error("Failed to marshal invalidate message:", err) - return - } - - // Non-blocking send with timeout - select { - case h.broadcast <- data: - case <-time.After(100 * time.Millisecond): - logger.Warning("WebSocket broadcast channel is full, dropping invalidate message") - case <-h.ctx.Done(): - } -} - -// getCurrentTimestamp returns current Unix timestamp in milliseconds -func getCurrentTimestamp() int64 { - return time.Now().UnixMilli() -} |
