Welcome to mirror list, hosted at ThFree Co, Russian Federation.

websocket.go « controller « web - github.com/MHSanaei/3x-ui.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0ad5c8459a26187c7a0d6ec30788e229fba99447 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package controller

import (
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/mhsanaei/3x-ui/v2/logger"
	"github.com/mhsanaei/3x-ui/v2/util/common"
	"github.com/mhsanaei/3x-ui/v2/web/session"
	"github.com/mhsanaei/3x-ui/v2/web/websocket"

	"github.com/gin-gonic/gin"
	ws "github.com/gorilla/websocket"
)

// Timing and size limits applied to every WebSocket connection.
const (
	// writeWait is the deadline granted to each individual write to the peer.
	writeWait = 10 * time.Second

	// pongWait is how long the reader waits before the read deadline expires;
	// the deadline is pushed forward by this amount whenever a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is the interval between pings. It has to stay below pongWait
	// so a healthy peer's pong lands before the read deadline runs out.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the read limit, in bytes, for a single inbound message.
	maxMessageSize = 512
)

// upgrader converts authenticated HTTP requests into WebSocket connections.
// Read and write buffers are 4 KiB each; cross-origin requests are filtered by
// checkWebSocketOrigin.
var upgrader = ws.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	CheckOrigin:     checkWebSocketOrigin,
}

// checkWebSocketOrigin reports whether a WebSocket handshake may proceed. It
// guards against cross-site WebSocket hijacking by requiring the Origin header,
// when present, to name the same host as the request's Host header.
//
// Rules:
//   - no Origin header: allowed (RFC 6455 requires browser clients to send
//     Origin, so its absence does not indicate a cross-site browser request);
//   - Origin must parse as an http or https URL, otherwise rejected;
//   - host names are compared case-insensitively and without ports, so the
//     panel may be reached through a different port than the one in Origin;
//   - IPv6 literals are compared without their surrounding brackets;
//   - if either host name is empty, same-origin cannot be established and the
//     request is rejected.
func checkWebSocketOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true
	}

	u, err := url.Parse(origin)
	if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
		// Covers "null" origins and anything that is not an http(s) URL.
		return false
	}

	// URL.Hostname drops the port and the brackets around IPv6 literals, so
	// "[::1]:2053" and "[::1]" both yield "::1". Splitting on the first ':'
	// (as a naive implementation would) reduces every IPv6 host to "[" and lets
	// any IPv6 origin match any IPv6 host.
	originHost := u.Hostname()
	requestHost := (&url.URL{Host: r.Host}).Hostname()
	if originHost == "" || requestHost == "" {
		return false
	}

	// Host names are case-insensitive.
	return strings.EqualFold(originHost, requestHost)
}

// WebSocketController upgrades HTTP requests to WebSocket connections and
// attaches each accepted connection to a shared hub for real-time updates.
type WebSocketController struct {
	BaseController

	// hub is the hub every accepted client is registered with.
	hub *websocket.Hub
}

// NewWebSocketController returns a controller whose clients are registered
// with the given hub.
func NewWebSocketController(hub *websocket.Hub) *WebSocketController {
	ctrl := new(WebSocketController)
	ctrl.hub = hub
	return ctrl
}

// HandleWebSocket serves the WebSocket endpoint.
//
// Requests without a logged-in session are rejected with 401 before any
// upgrade is attempted. Otherwise the connection is upgraded, a new client
// with a random UUID is registered with the hub, and its write and read pumps
// are started in separate goroutines; this handler then returns immediately.
func (w *WebSocketController) HandleWebSocket(c *gin.Context) {
	if !session.IsLogin(c) {
		logger.Warningf("Unauthorized WebSocket connection attempt from %s", getRemoteIp(c))
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// On failure the upgrader has already written an HTTP error response.
		logger.Error("Failed to upgrade WebSocket connection:", err)
		return
	}

	client := &websocket.Client{
		ID:  uuid.New().String(),
		Hub: w.hub,
		// Outbound queue drained by writePump.
		Send:   make(chan []byte, 512),
		Topics: make(map[websocket.MessageType]bool),
	}

	w.hub.Register(client)
	logger.Debugf("WebSocket client %s registered from %s", client.ID, getRemoteIp(c))

	// writePump performs every write and readPump every read on conn, which
	// respects gorilla/websocket's one-reader/one-writer concurrency rule.
	go w.writePump(client, conn)
	go w.readPump(client, conn)
}

// readPump consumes inbound messages from conn on behalf of client until a
// read fails, then unregisters the client from the hub and closes conn.
//
// It runs in its own goroutine and is the only reader of conn. The read
// deadline starts at pongWait and is extended by pongWait each time a pong
// arrives, so a peer that stops answering writePump's pings is dropped.
// Messages larger than maxMessageSize end the connection (enforced by
// SetReadLimit). Inbound messages are not acted upon yet; they are only logged.
func (w *WebSocketController) readPump(client *websocket.Client, conn *ws.Conn) {
	// common.Recover returns the recovered panic value, so it must call
	// recover() itself. recover() only stops a panic when it is called directly
	// by the deferred function, so common.Recover is deferred directly here.
	// Calling it from inside a deferred closure would let the panic escape and
	// crash the whole process. It is registered first so it runs last and also
	// covers any panic raised by the cleanup below.
	defer common.Recover("WebSocket readPump panic")
	defer func() {
		w.hub.Unregister(client)
		conn.Close()
	}()

	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	conn.SetReadLimit(maxMessageSize)

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			switch {
			case err == ws.ErrReadLimit:
				// An oversized message never reaches the loop body:
				// SetReadLimit makes ReadMessage fail with ErrReadLimit instead.
				logger.Warningf("WebSocket message from client %s exceeds max size of %d bytes", client.ID, maxMessageSize)
			case ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseAbnormalClosure):
				logger.Debugf("WebSocket read error for client %s: %v", client.ID, err)
			}
			return
		}

		// Subscription requests are not handled yet. Log the payload quoted,
		// because it is client-controlled and may contain control characters.
		logger.Debugf("Received WebSocket message from client %s: %q", client.ID, message)
	}
}

// writePump delivers messages queued on client.Send to conn and pings the
// peer every pingPeriod, so readPump's pong-driven read deadline keeps moving.
//
// It runs in its own goroutine and is the only writer to conn. It returns when
// client.Send is closed (after sending a close frame) or when any write or
// ping fails. It always stops the ticker and closes conn on the way out.
func (w *WebSocketController) writePump(client *websocket.Client, conn *ws.Conn) {
	ticker := time.NewTicker(pingPeriod)

	// common.Recover returns the recovered panic value, so it must call
	// recover() itself. recover() only stops a panic when it is called directly
	// by the deferred function, so common.Recover is deferred directly here.
	// Calling it from inside a deferred closure would let the panic escape and
	// crash the whole process. It is registered first so it runs last and also
	// covers any panic raised by the cleanup below.
	defer common.Recover("WebSocket writePump panic")
	defer func() {
		ticker.Stop()
		conn.Close()
	}()

	for {
		select {
		case message, ok := <-client.Send:
			conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The hub closed the channel. The close frame is best effort,
				// since the connection is closed on return either way.
				_ = conn.WriteMessage(ws.CloseMessage, []byte{})
				return
			}

			// One text frame per queued message (no batching), so every JSON
			// payload arrives as its own message and can be parsed on its own.
			if err := conn.WriteMessage(ws.TextMessage, message); err != nil {
				logger.Debugf("WebSocket write error for client %s: %v", client.ID, err)
				return
			}

		case <-ticker.C:
			conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := conn.WriteMessage(ws.PingMessage, nil); err != nil {
				logger.Debugf("WebSocket ping error for client %s: %v", client.ID, err)
				return
			}
		}
	}
}