Jelajahi Sumber

add: websocket库支持自定义headers https://gitee.com/openLuat/LuatOS/issues/I6B52C

Wendal Chen 3 tahun lalu
induk
melakukan
75b1996ec8

+ 95 - 14
components/network/websocket/luat_lib_websocket.c

@@ -57,6 +57,7 @@ static luat_websocket_ctrl_t *get_websocket_ctrl(lua_State *L)
 
 static int32_t l_websocket_callback(lua_State *L, void *ptr)
 {
+	(void)ptr;
 	rtos_msg_t *msg = (rtos_msg_t *)lua_topointer(L, -1);
 	luat_websocket_ctrl_t *websocket_ctrl = (luat_websocket_ctrl_t *)msg->ptr;
 	luat_websocket_pkg_t pkg = {0};
@@ -198,7 +199,7 @@ static int l_websocket_create(lua_State *L)
 	luat_websocket_connopts_t opts = {0};
 
 	// 连接参数相关
-	const char *ip;
+	// const char *ip;
 	size_t ip_len = 0;
 #ifdef LUAT_USE_LWIP
 	websocket_ctrl->ip_addr.type = 0xff;
@@ -295,13 +296,13 @@ static int l_websocket_autoreconn(lua_State *L)
 @string 待发送的数据,必填
 @int 是否为最后一帧,默认1
 @int 操作码, 默认为字符串帧
-@return int 消息id, 当qos为1或2时会有效值. 若底层返回是否, 会返回nil
+@return bool 成功返回true,否则为false或者nil
 @usage
 wsc:publish("/luatos/123456", "123")
 */
 static int l_websocket_send(lua_State *L)
 {
-	uint32_t payload_len = 0;
+	size_t payload_len = 0;
 	luat_websocket_ctrl_t *websocket_ctrl = get_websocket_ctrl(L);
 	const char *payload = NULL;
 	luat_zbuff_t *buff = NULL;
@@ -327,7 +328,8 @@ static int l_websocket_send(lua_State *L)
 		.plen = payload_len,
 		.payload = payload};
 	ret = luat_websocket_send_frame(websocket_ctrl, &pkg);
-	return 0;
+	lua_pushboolean(L, ret == 0 ? 1 : 0);
+	return 1;
 }
 
 /*
@@ -355,7 +357,7 @@ websocket客户端是否就绪
 @api wsc:ready()
 @return bool 客户端是否就绪
 @usage
-local error = wsc:ready()
+local stat = wsc:ready()
 */
 static int l_websocket_ready(lua_State *L)
 {
@@ -364,6 +366,82 @@ static int l_websocket_ready(lua_State *L)
 	return 1;
 }
 
+/*
+设置额外的headers
+@api wsc:headers(headers)
+@table/string 可以是table,也可以是字符串
+@return bool 客户端是否就绪
+@usage
+-- table形式
+wsc:headers({
+	Auth="Basic ABCDEFGG"
+})
+-- 字符串形式
+wsc:headers("Auth: Basic ABCDERG\r\n")
+*/
+static int l_websocket_headers(lua_State *L)
+{
+	luat_websocket_ctrl_t *websocket_ctrl = get_websocket_ctrl(L);
+	if (!lua_istable(L, 2) && !lua_isstring(L, 2)) {
+		return  0;
+	}
+	#define WS_HEADER_MAX (1024)
+	char* buff = luat_heap_malloc(WS_HEADER_MAX);
+	memset(buff, 0, WS_HEADER_MAX);
+	if (lua_istable(L, 2)) {
+		size_t name_sz = 0;
+		size_t value_sz = 0;
+		lua_pushnil(L);
+		while (lua_next(L, 2) != 0) {
+			const char *name = lua_tolstring(L, -2, &name_sz);
+			const char *value = lua_tolstring(L, -1, &value_sz);
+			if (name_sz == 0 || value_sz == 0 || name_sz + value_sz > 256) {
+				LLOGW("bad header %s %s", name, value);
+				luat_heap_free(buff);
+				return 0;
+			}
+			memcpy(buff + strlen(buff), name, name_sz);
+			memcpy(buff + strlen(buff), ":", 1);
+			if (WS_HEADER_MAX - strlen(buff) < value_sz * 2) {
+				LLOGW("bad header %s %s, too large", name, value);
+				luat_heap_free(buff);
+				return 0;
+			}
+			for (size_t i = 0; i < value_sz; i++)
+			{
+				switch (value[i])
+				{
+				case '*':
+				case '-':
+				case '.':
+				case '_':
+				case ' ':
+					sprintf_(buff + strlen(buff), "%%%02X", value[i]);
+					break;
+				default:
+					buff[strlen(buff)] = value[i];
+					break;
+				}
+			}
+			lua_pop(L, 1);
+			memcpy(buff + strlen(buff), "\r\n", 2);
+		}
+	}
+	else {
+		size_t len = 0;
+		const char* data = luaL_checklstring(L, 2, &len);
+		if (len > 1023) {
+			LLOGW("headers too large size %d", len);
+			luat_heap_free(buff);
+			return 0;
+		}
+		memcpy(buff, data, len);
+	}
+	luat_websocket_set_headers(websocket_ctrl, buff);
+	lua_pushboolean(L, 1);
+	return 1;
+}
+
 static int _websocket_struct_newindex(lua_State *L);
 
 void luat_websocket_struct_init(lua_State *L)
@@ -377,19 +455,22 @@ void luat_websocket_struct_init(lua_State *L)
 #include "rotable2.h"
 const rotable_Reg_t reg_websocket[] =
 	{
-		{"create", ROREG_FUNC(l_websocket_create)},
-		{"on", ROREG_FUNC(l_websocket_on)},
-		{"connect", ROREG_FUNC(l_websocket_connect)},
-		{"autoreconn", ROREG_FUNC(l_websocket_autoreconn)},
-		{"send", ROREG_FUNC(l_websocket_send)},
-		{"close", ROREG_FUNC(l_websocket_close)},
-		{"ready", ROREG_FUNC(l_websocket_ready)},
+		{"create", 			ROREG_FUNC(l_websocket_create)},
+		{"on", 				ROREG_FUNC(l_websocket_on)},
+		{"connect", 		ROREG_FUNC(l_websocket_connect)},
+		{"autoreconn", 		ROREG_FUNC(l_websocket_autoreconn)},
+		{"send", 			ROREG_FUNC(l_websocket_send)},
+		{"close", 			ROREG_FUNC(l_websocket_close)},
+		{"ready", 			ROREG_FUNC(l_websocket_ready)},
+		{"headers", 		ROREG_FUNC(l_websocket_headers)},
+		{"debug",           ROREG_FUNC(l_websocket_set_debug)},
 
-		{NULL, ROREG_INT(0)}};
+		{NULL, 				ROREG_INT(0)}
+};
 
 int _websocket_struct_newindex(lua_State *L)
 {
-	rotable_Reg_t *reg = reg_websocket;
+	const rotable_Reg_t *reg = reg_websocket;
 	const char *key = luaL_checkstring(L, 2);
 	while (1)
 	{

+ 33 - 10
components/network/websocket/luat_websocket.c

@@ -281,6 +281,10 @@ void luat_websocket_release_socket(luat_websocket_ctrl_t *websocket_ctrl)
 		luat_release_rtos_timer(websocket_ctrl->reconnect_timer);
     	websocket_ctrl->reconnect_timer = NULL;
 	}
+	if (websocket_ctrl->headers) {
+		luat_heap_free(websocket_ctrl->headers);
+		websocket_ctrl->headers = NULL;
+	}
 	if (websocket_ctrl->netc)
 	{
 		network_release_ctrl(websocket_ctrl->netc);
@@ -288,6 +292,13 @@ void luat_websocket_release_socket(luat_websocket_ctrl_t *websocket_ctrl)
 	}
 }
 
+static const char* ws_headers = 
+						"Upgrade: websocket\r\n"
+						"Connection: Upgrade\r\n"
+						"Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n"
+						"Sec-WebSocket-Version: 13\r\n"
+						"\r\n";
+
 static int websocket_connect(luat_websocket_ctrl_t *websocket_ctrl)
 {
 	LLOGD("request host %s port %d uri %s", websocket_ctrl->host, websocket_ctrl->remote_port, websocket_ctrl->uri);
@@ -295,15 +306,14 @@ static int websocket_connect(luat_websocket_ctrl_t *websocket_ctrl)
 	int ret = snprintf_((char*)websocket_ctrl->pkg_buff,
 						WEBSOCKET_RECV_BUF_LEN_MAX,
 						"GET %s HTTP/1.1\r\n"
-						"Host: %s\r\n"
-						"Upgrade: websocket\r\n"
-						"Connection: Upgrade\r\n"
-						"Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n"
-						"Sec-WebSocket-Version: 13\r\n"
-						"\r\n",
+						"Host: %s\r\n",
 						websocket_ctrl->uri, websocket_ctrl->host);
-	LLOGD("Request %s", websocket_ctrl->pkg_buff);
+	//LLOGD("Request %s", websocket_ctrl->pkg_buff);
 	ret = luat_websocket_send_packet(websocket_ctrl, websocket_ctrl->pkg_buff, ret);
+	if (websocket_ctrl->headers) {
+		luat_websocket_send_packet(websocket_ctrl, websocket_ctrl->headers, strlen(websocket_ctrl->headers));
+	}
+	luat_websocket_send_packet(websocket_ctrl, ws_headers, strlen(ws_headers));
 	LLOGD("websocket_connect ret %d", ret);
 	return ret;
 }
@@ -441,7 +451,7 @@ static int websocket_parse(luat_websocket_ctrl_t *websocket_ctrl)
 			return -1;
 		}
 		memcpy(buff, buf, pkg_len);
-		l_luat_websocket_msg_cb(websocket_ctrl, WEBSOCKET_MSG_PUBLISH, buff);
+		l_luat_websocket_msg_cb(websocket_ctrl, WEBSOCKET_MSG_PUBLISH, (int)buff);
 	}
 
 	// 处理完成后, 如果还有数据, 移动数据, 继续处理
@@ -457,8 +467,8 @@ static int websocket_parse(luat_websocket_ctrl_t *websocket_ctrl)
 int luat_websocket_read_packet(luat_websocket_ctrl_t *websocket_ctrl)
 {
 	// LLOGD("luat_websocket_read_packet websocket_ctrl->buffer_offset:%d",websocket_ctrl->buffer_offset);
-	int ret = -1;
-	uint8_t *read_buff = NULL;
+	// int ret = -1;
+	// uint8_t *read_buff = NULL;
 	uint32_t total_len = 0;
 	uint32_t rx_len = 0;
 	int result = network_rx(websocket_ctrl->netc, NULL, 0, 0, NULL, NULL, &total_len);
@@ -628,3 +638,16 @@ int luat_websocket_connect(luat_websocket_ctrl_t *websocket_ctrl)
 	}
 	return 0;
 }
+
+int luat_websocket_set_headers(luat_websocket_ctrl_t *websocket_ctrl, const char *headers) {
+	if (websocket_ctrl == NULL)
+		return 0;
+	if (websocket_ctrl->headers != NULL) {
+		luat_heap_free(websocket_ctrl->headers);
+		websocket_ctrl->headers = NULL;
+	}
+	if (headers) {
+		websocket_ctrl->headers = headers;
+	}
+	return 0;
+}

+ 2 - 1
components/network/websocket/luat_websocket.h

@@ -32,6 +32,7 @@ typedef struct
 	void *reconnect_timer;	 // websocket重连定时器
 	void *ping_timer;		 // websocket_ping定时器
 	int websocket_ref;		 // 强制引用自身避免被GC
+	char* headers;
 } luat_websocket_ctrl_t;
 
 typedef struct luat_websocket_connopts
@@ -69,5 +70,5 @@ int luat_websocket_init(luat_websocket_ctrl_t *websocket_ctrl, int adapter_index
 int luat_websocket_set_connopts(luat_websocket_ctrl_t *websocket_ctrl, const char *url);
 int luat_websocket_payload(char *buff, luat_websocket_pkg_t *pkg, size_t limit);
 int luat_websocket_send_frame(luat_websocket_ctrl_t *websocket_ctrl, luat_websocket_pkg_t *pkg);
-
+int luat_websocket_set_headers(luat_websocket_ctrl_t *websocket_ctrl, const char *headers);
 #endif

+ 5 - 2
demo/websocket/main.lua

@@ -19,8 +19,8 @@ local wsc = nil
 
 sys.taskInit(function()
     if rtos.bsp():startsWith("ESP32") then
-        local ssid = "uiot123"
-        local password = "12348888"
+        local ssid = "uiot"
+        local password = "1234567890"
         log.info("wifi", ssid, password)
         -- TODO 改成esptouch配网
         LED = gpio.setup(12, 0, gpio.PULLUP)
@@ -46,6 +46,9 @@ sys.taskInit(function()
 
     -- 这是个测试服务, 当发送的是json,且action=echo,就会回显所发送的内容
     wsc = websocket.create(nil, "ws://echo.airtun.air32.cn/ws/echo")
+    if wsc.headers then
+        wsc:headers({Auth="Basic ABCDEGG"})
+    end
     wsc:autoreconn(true, 3000) -- 自动重连机制
     wsc:on(function(wsc, event, data, fin, optcode)
         -- event 事件, 当前有conack和recv