SpringBoot都是绿色的
## 前言

由于网上各式各样的 WebSocket 代码都太旧了,前端后端代码的坑又特别多,所以总结了一套自己常用的 WebSocket 代码,这个代码仅针对单体应用


背景

前端需要与后端建立一个长连接,用于传递新用户订阅的通知,并实时发送给前端的所有用户。传统的代码能用,但却面临以下痛点:

  • 如何校验连接的安全性
  • 如何防止连接断开

对于 如何校验连接的安全性 ,网上许多拆东墙补西墙的方法,是在 URL 做手脚,也就是将 tokenws://xx.xx.xx.xx:8080/application/{userId}?token={token} 这样的形式发送给后端,后端通过获取请求行中的内容,得到 token,再进行验证操作。

但有一天我通过 goaccess 分析 Nginx 日志,一下子就看到了许多用户的 ws 请求,上面的 token 写的可真是一清二楚,这对于生产环境来说,太危险了,如果有人得到了这份日志,就可以匿名做任何事情。

如果不在 URL 做手脚,但 WebSocket 又不支持请求体存放内容,那该怎么办呢?

答案就是在 Sec-WebSocket-Protocol 上,把 token 写入即可。

Sec-WebSocket-Protocol是WebSocket协议的一个扩展,它允许客户端和服务器在握手过程中传递一些协议名称,以便它们可以在建立WebSocket连接后使用这些协议进行通信。

在客户端和服务器进行WebSocket握手时,客户端可以在请求头中使用Sec-WebSocket-Protocol字段指定所支持的协议,服务器在响应头中返回一个协议,以指明实际采用的协议。如果客户端和服务器都支持同一个协议,则建立的WebSocket连接将使用该协议进行通信。

这个扩展的使用场景比较广泛,例如在实时音视频通话、游戏开发、实时消息传递等领域,可以使用Sec-WebSocket-Protocol来实现更高效的通信。

对于 如何防止连接断开,这个方法就和网上流传的一样,满足心跳机制即可。


版本说明

  • SpringBoot:2.7.7
  • JDK: 1.8

Maven引入

1
2
3
4
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

后端代码

这是一个基于 Spring 框架的 WebSocket 配置类,用于自动注册 WebSocket 端点的 bean,以便服务器能够处理客户端的 WebSocket 请求。该配置类通常用于配置 WebSocket 服务端的相关配置,以便与客户端建立 WebSocket 连接并进行双向通信。

该类中的 serverEndpointExporter() 方法返回一个 ServerEndpointExporter 实例,该实例会在Spring应用程序启动时扫描所有标注了 @ServerEndpoint 注解的类,并将其注册为 WebSocket 端点。

该类中的 createWebSocketContainer() 方法返回一个 ServletServerContainerFactoryBean 实例,通过设置 maxTextMessageBufferSizemaxBinaryMessageBufferSize 属性,可以设置文本消息缓冲区二进制消息缓冲区的大小,以便控制 WebSocket消息的大小限制。

P.S. 因为我只用到了发送文本,没有发送图片的功能,请自行斟酌要发送的字节大小,如果在 ServletServerContainerFactoryBean 中没有设置 maxTextMessageBufferSizemaxBinaryMessageBufferSize 属性,则默认情况下它们的值都是 -1,表示没有限制,可以接受任意大小WebSocket消息,这是极其危险的事情

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;

@Configuration
@Slf4j
public class WebSocketConfig {

@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}

@Bean
public ServletServerContainerFactoryBean createWebSocketContainer() {
ServletServerContainerFactoryBean bean = new ServletServerContainerFactoryBean();
bean.setMaxTextMessageBufferSize(8192);
bean.setMaxBinaryMessageBufferSize(8192);
return bean;
}
}

这是一个枚举,用于标识消息的类型(可跟前端约束、自定义这里的属性和内容)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public enum WebSocketMessageTypeEnum {
HEART("心跳", "1"),
MESSAGE("普通消息", "2"));

private final String name;
private final String value;

WebSocketMessageTypeEnum(String name, String value) {
this.name = name;
this.value = value;
}

public String getName() {
return name;
}

public String getValue() {
return value;
}
}

subprotocols = {"leopold"} 中的内容可以换成任意自定义内容,只要前后端发送的 Header 一致就可以,否则前后端会自动丢弃 WebSocket 的连接。在 // TODO your business token check 这里填入你业务系统的token校验即可

通过重写 CustomWebSocketConfigurator,获取到自定义的 Header,为后续方法建立基础。session.getBasicRemote()是同步发送。TASK_POOL 是一个固定大小的线程池,如果提交的任务数量超过了线程池的大小,那么这些任务将被放入队列中等待执行。如果某个线程在执行任务时发生了异常,它将停止执行当前任务,但线程池中的其他线程将继续执行未完成的任务。如果队满再放入任务,则会阻塞代码,并填入内存中,直到内存已满。

使用这个线程池的原因也是因为这种阻塞行为可以确保任务不会被丢失,同时避免创建过多的线程。如果你的消息允许丢失,可以选择其他线程池。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;

import javax.websocket.*;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
*
* Message Center Based On Websocket <br/>
*
* <p>1. We use <b>RESTFul API</b> to ctrl message center</p>
* <p>2. We don't choose NIO (Even though Tomcat Version > 6.0 is NIO), because :</p>
* <p>2.1 Our members are below to 1,000. (we should think about new NIO such
* as Netty if our online user up to 1,000)</p>
* <p>2.2 Our SpringBoot system is based on Tomcat. It's hard to change container and
* we don't have time to get Regression Test</p>
* <p>2.3 Tomcat NIO is a typical singleton thread model, but Netty is Reactor model
* which is used on high concurrency.</p><br/>
*
* <p> WARNING: </p>
* If you always catch {@link java.io.EOFException}, you should add connect time out
* on nginx such as {@code proxy_connect_timeout 3600s;} P.S. Netty is a better choice
* maybe I will recode this class :)<br/>
*
* @author leopold
* @version 1.0
*/
@ServerEndpoint(value = "/websocket/{userId}", subprotocols = {"wxatrep"}, configurator = WebSocketServer.CustomWebSocketConfigurator.class)
@Component
@Slf4j
public class WebSocketServer {

/**
* Our online member number. We used {@link java.util.concurrent.atomic}
* to synchronize integer so that we could get correct number.
*/
private static final AtomicInteger ONLINE_COUNT = new AtomicInteger(0);

/**
* Our online clients hashmap. We used {@link ConcurrentHashMap} to
* save sessions. The key is userId which is based on business system
* and the value is {@link Session}
*/
private static final ConcurrentHashMap<String, Session> CLIENTS_MAP = new ConcurrentHashMap<>();

/**
* Our send message task pool. We used {@link ExecutorService} to ctrl
* send thread pool. We used {@link Executors#newFixedThreadPool(int)}
* to create a max limit thread in pool. Normally we create a (cpu core
* num +1) in task pool.
* <br/>
* P.S. This pool maybe fill up your system memory when your thread task is doing
* slowly. If your task is not important, and it could be tolerated, you
* should use a bounded queue instead of it.
*/
private static final ExecutorService TASK_POOL = Executors.newFixedThreadPool(3);

/**
* Our userId which is from business system
*/
private volatile String userId;

public static class CustomWebSocketConfigurator extends ServerEndpointConfig.Configurator {
@Override
public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
super.modifyHandshake(config, request, response);

// get sec-websocket-protocol
List<String> tokenList = request.getHeaders().get("sec-websocket-protocol");
if (tokenList != null && tokenList.size() > 0) {
String token = tokenList.get(0);
config.getUserProperties().put("token", token.split(",")[0].trim());
}
}
}

/**
* When client create a new websocket connection, we catch the code
* and put userId and session into {@link WebSocketServer#CLIENTS_MAP}
*
* @param userId userId
* @param session user session
*/
@OnOpen
public void onOpen(@PathParam("userId") String userId, Session session) throws Exception {
Map<String, Object> userProperties = session.getUserProperties();
Object tokenO = userProperties.get("token");
if (ObjectUtils.isEmpty(tokenO)) {
throw new Exception("token not found");
}
String token = tokenO.toString();

// TODO your business token check

this.userId = userId;
CLIENTS_MAP.put(userId, session);
log.info("A new webSocket connected, userId -> [{}], all connected num -> [{}]", userId,
ONLINE_COUNT.incrementAndGet());
}

/**
* When client closed a websocket connection, we catch the code
* and remove session from {@link WebSocketServer#CLIENTS_MAP} by
* userId. This remove method does nothing if the key is not in
* the map.
*/
@OnClose
public void onClose() {
if (!ObjectUtils.isEmpty(userId)) {
CLIENTS_MAP.remove(userId);
log.info("A webSocket closed, userId -> [{}], current connected num -> [{}]", userId,
ONLINE_COUNT.decrementAndGet());
}
}

/**
* When client send a message to service, then we catch the code.
* I did nothing because our business system didn't have to save it.
*
* @param message client send message
* @param session client session
*/
@OnMessage
public void onMessage(@PathParam("userId") String userId, String message, Session session) throws IOException {
// HEART CHECK
if (!ObjectUtils.isEmpty(message) && WebSocketMessageTypeEnum.HEART.name().equals(message)) {
sendMessage(message, session);
} else {
log.info("get message from client -> [{}]", message);
}
}

/**
*
* When the websocket goes errors, then we catch the code. For
* some reasons, we always catch {@link java.io.EOFException}
* because the http connections goes max connecting time.
*
*/
@OnError
public void onError(Session session, Throwable error) {
if (error instanceof java.io.EOFException) {
log.warn("ws normally disconnect with max connect timeout, session id -> [{}]", session.getId());
} else {
log.error("ws error -> {} session id -> [{}]", error.getMessage(), session.getId(), error);
}
try {
session.close();
} catch (IOException e) {
log.error("ws close error -> {} session id -> [{}]", error.getMessage(), session.getId(), error);
}
}

/**
* Broadcast message on all online user
* @param message message body
*/
public static void broadcastMessage(String message) {
TASK_POOL.submit(() -> {
for (Session session : CLIENTS_MAP.values()) {
try {
sendMessage(message, session);
} catch (IOException e) {
log.error("broadcast message error -> {}", e.getMessage(), e);
}
}
});
}

/**
* Send message to user by userId
* @param message message body
* @param userId userId
*/
public static void sendMessage(String message, String userId) {
for (String clientUserId : CLIENTS_MAP.keySet()) {
if (userId.equals(clientUserId)) {
TASK_POOL.submit(() -> {
try {
sendMessage(message, CLIENTS_MAP.get(clientUserId));
} catch (IOException e) {
log.error("send message error, userId -> [{}], message -> [{}], because -> [{}]",
userId, message, e.getMessage(), e);
}
});
break;
}
}
}


/**
* Send message to user by userIdList
* @param message message body
* @param userIdList userIdList
*/
public static void sendMessage(String message, List<String> userIdList) {
TASK_POOL.submit(() -> {
for (String clientUserId : CLIENTS_MAP.keySet()) {
if (userIdList.contains(clientUserId)) {
try {
sendMessage(message, CLIENTS_MAP.get(clientUserId));
} catch (IOException e) {
log.error("send message error, userId -> [{}], message -> [{}], because -> [{}]",
clientUserId, message, e.getMessage(), e);
}
}
}
});
}


// ------------------------------------------------------------------- private method


/**
* Core send message method
* @param message message body
* @param session user session
* @throws IOException session error normally
*/
private static void sendMessage(String message, Session session) throws IOException {
// Check session is not closed
if (session.isOpen()) {
// Not async sending message
session.getBasicRemote().sendText(message);
} else {
log.warn("session [{}] is closed, cannot send message", session.getId());
}
}
}

前端代码

由于我的 WebSocket 只需要登录后建立,所以我把登录后的 main.js 引入这个工具包即可 import socketPublic from '@/utils/websocket.js'

state.ws = new WebSocket(process.env.VUE_APP_WEBSOCKET_URL + store.getters.id, [store.getters.token, 'leopold']) 请自行替换为你的地址,但第二个参数请保持 [token, '自定义协议名称']

@/utils/websocket.js

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import Vue from 'vue'
import Vuex from 'vuex'
import store from '@/store'
const v = new Vue()
Vue.use(Vuex)

export default new Vuex.Store({
state: {
ws: null, // 建立的连接
lockReconnect: false, // 是否真正建立连接
timeout: 15000, // 15秒一次心跳
timeoutObj: null, // 心跳心跳倒计时
serverTimeoutObj: null, // 心跳倒计时
timeoutnum: null, // 断开 重连倒计时
msg: null // 接收到的信息
},
getters: {
// 获取接收的信息
socketMsgs: state => {
return state.msg
}
},
mutations: {
// 初始化ws 用户登录后调用
webSocketInit(state) {
const that = this
// this 创建一个state.ws对象【发送、接收、关闭socket都由这个对象操作】
state.ws = new WebSocket(process.env.VUE_APP_WEBSOCKET_URL + store.getters.id, [store.getters.token, 'leopold'])
state.ws.onopen = function(res) {
console.log('Connection success...')
/**
* 启动心跳检测
*/
that.commit('start')
}
state.ws.onmessage = function(res) {
if (res.data === 'HEART') {
// 收到服务器信息,心跳重置
that.commit('reset')
} else {
console.log('收到消息:' + res.data)
const msg = JSON.parse(res.data)
v.$notify({
title: msg.title,
message: msg.message,
type: msg.type,
duration: 5000
})
state.msg = res
}
}
state.ws.onclose = function(res) {
console.log('Connection closed...')
// 重连
that.commit('reconnect')
}
state.ws.onerror = function(res) {
console.log('Connection error...')
// 重连
that.commit('reconnect')
}
},
reconnect(state) {
// 重新连接
const that = this
if (state.lockReconnect) {
return
}
state.lockReconnect = true
// 没连接上会一直重连,30秒重试请求重连,设置延迟避免请求过多
state.timeoutnum &&
clearTimeout(state.timeoutnum)
state.timeoutnum = setTimeout(() => {
// 新连接
that.commit('webSocketInit')
state.lockReconnect = false
}, 5000)
},
reset(state) {
// 重置心跳
const that = this
// 清除时间
clearTimeout(state.timeoutObj)
clearTimeout(state.serverTimeoutObj)
// 重启心跳
that.commit('start')
},
start(state) {
// 开启心跳
var self = this
state.timeoutObj &&
clearTimeout(state.timeoutObj)
state.serverTimeoutObj &&
clearTimeout(state.serverTimeoutObj)
state.timeoutObj = setTimeout(() => {
// 这里发送一个心跳,后端收到后,返回一个心跳消息,
if (state.ws.readyState === 1) {
// 如果连接正常
state.ws.send('HEART')
} else {
// 否则重连
self.commit('reconnect')
}
state.serverTimeoutObj = setTimeout(function() {
// 超时关闭
state.ws.close()
}, state.timeout)
}, state.timeout)
}
},
actions: {
webSocketInit({
commit
}, url) {
commit('webSocketInit', url)
},
webSocketSend({
commit
}, p) {
commit('webSocketSend', p)
}
}
})

效果演示

image-20230301170704045

image-20230301170741938

image-20230301165411922


其他小知识

来源:ChatGPT

image-20230302085656025