Просмотр исходного кода

update: netdrv,napt,修正端口位图计算错误

Wendal Chen 3 месяцев назад
Родитель
Сommit
b98b8f4b18

+ 4 - 0
components/network/netdrv/include/luat_netdrv_napt.h

@@ -80,6 +80,9 @@ int luat_napt_icmp_handle(napt_ctx_t* ctx);
 int luat_napt_tcp_handle(napt_ctx_t* ctx);
 int luat_napt_udp_handle(napt_ctx_t* ctx);
 
+void luat_netdrv_napt_tcp_cleanup(void);
+void luat_netdrv_napt_udp_cleanup(void);
+
 int luat_netdrv_napt_pkg_input(int id, uint8_t* buff, size_t len);
 
 int luat_netdrv_napt_pkg_input_pbuf(int id, struct pbuf* p);
@@ -89,5 +92,6 @@ int luat_netdrv_napt_tcp_wan2lan(napt_ctx_t* ctx, luat_netdrv_napt_tcpudp_t* map
 int luat_netdrv_napt_tcp_lan2wan(napt_ctx_t* ctx, luat_netdrv_napt_tcpudp_t* mapping, luat_netdrv_napt_ctx_t *napt_ctx);
 
 void luat_netdrv_napt_enable(int adapter_id);
+void luat_netdrv_napt_disable(void);
 
 #endif

+ 56 - 10
components/network/netdrv/src/luat_netdrv_napt.c

@@ -39,6 +39,8 @@ luat_netdrv_napt_ctx_t *g_napt_udp_ctx;
 // 端口分配
 #define NAPT_PORT_RANGE_START     0x1BBC
 #define NAPT_PORT_RANGE_END       0x6AAA
+#define NAPT_PORT_COUNT           (NAPT_PORT_RANGE_END - NAPT_PORT_RANGE_START + 1)
+#define NAPT_PORT_BITMAP_SIZE     ((NAPT_PORT_COUNT + 31) / 32 * sizeof(uint32_t))
 
 #define u32 uint32_t
 #define u16 uint16_t
@@ -225,9 +227,15 @@ static int ctx_init(luat_netdrv_napt_ctx_t** ctx_ptrptr) {
     }
     memset(ctx, 0, sizeof(luat_netdrv_napt_ctx_t));
     luat_rtos_mutex_create(&ctx->lock);
-    size_t port_len = (NAPT_PORT_RANGE_END - NAPT_PORT_RANGE_START) / 4;
-    ctx->port_used = luat_heap_malloc(port_len + 8);
-    memset(ctx->port_used, 0, port_len + 8);
+    // 正确计算端口位图大小: (端口数 + 31) / 32 * 4 字节
+    size_t port_len = NAPT_PORT_BITMAP_SIZE;
+    ctx->port_used = luat_heap_malloc(port_len);
+    if (ctx->port_used == NULL) {
+        LLOGE("初始化napt port_used失败");
+        luat_heap_free(ctx);
+        return -1;
+    }
+    memset(ctx->port_used, 0, port_len);
     ctx->clean_tm = 1;
     ctx->item_max = NAPT_TCP_MAP_ITEM_MAX;
 
@@ -248,12 +256,48 @@ void luat_netdrv_napt_enable(int adapter_id) {
     }
     s_gw_adapter_id = adapter_id;
 }
+
+void luat_netdrv_napt_disable(void) {
+    s_gw_adapter_id = -1;
+    
+    // 清理TCP上下文
+    if (g_napt_tcp_ctx) {
+        if (g_napt_tcp_ctx->port_used) {
+            luat_heap_free(g_napt_tcp_ctx->port_used);
+            g_napt_tcp_ctx->port_used = NULL;
+        }
+        luat_rtos_mutex_delete(g_napt_tcp_ctx->lock);
+        luat_heap_free(g_napt_tcp_ctx);
+        g_napt_tcp_ctx = NULL;
+    }
+    
+    // 清理UDP上下文
+    if (g_napt_udp_ctx) {
+        if (g_napt_udp_ctx->port_used) {
+            luat_heap_free(g_napt_udp_ctx->port_used);
+            g_napt_udp_ctx->port_used = NULL;
+        }
+        luat_rtos_mutex_delete(g_napt_udp_ctx->lock);
+        luat_heap_free(g_napt_udp_ctx);
+        g_napt_udp_ctx = NULL;
+    }
+    
+    // 清理缓冲区
+    if (napt_buff) {
+        luat_heap_free(napt_buff);
+        napt_buff = NULL;
+    }
+    
+    // 调用TCP/UDP的cleanup函数
+    luat_netdrv_napt_tcp_cleanup();
+    luat_netdrv_napt_udp_cleanup();
+}
 __NETDRV_CODE_IN_RAM__ static size_t luat_napt_tcp_port_alloc(luat_netdrv_napt_ctx_t *napt_ctx) {
     size_t offset;
     size_t soffset;
-    for (size_t i = 0; i <= NAPT_PORT_RANGE_END - NAPT_PORT_RANGE_START; i++) {
-        offset = i / ( 4 * 8);
-        soffset = i % ( 4 * 8);
+    for (size_t i = 0; i < NAPT_PORT_COUNT; i++) {
+        offset = i / 32;  // 每个uint32_t占32位
+        soffset = i % 32;
         if ((napt_ctx->port_used[offset] & (1 << soffset)) == 0) {
             napt_ctx->port_used[offset] |= (1 << soffset);
             return i + NAPT_PORT_RANGE_START;
@@ -313,10 +357,12 @@ __NETDRV_CODE_IN_RAM__ static void mapping_cleanup(luat_netdrv_napt_ctx_t *napt_
             it->is_vaild = 0;
             it->tm_ms = 0;
             port = it->wnet_local_port - NAPT_PORT_RANGE_START;
-            offset = port / ( 4 * 8);
-            soffset = port % ( 4 * 8);
-            if (offset > 1024) {
-                LLOGE("非法的offset %d", offset);
+            offset = port / 32;  // 每个uint32_t占32位
+            soffset = port % 32;
+            // 检查offset是否在有效范围内
+            size_t max_offset = (NAPT_PORT_COUNT + 31) / 32;
+            if (offset >= max_offset) {
+                LLOGE("非法的offset %d, 最大值应为 %d", offset, max_offset - 1);
             }
             else {
                 napt_ctx->port_used[offset] &= (~(1 << soffset));

+ 7 - 0
components/network/netdrv/src/luat_netdrv_napt_tcp.c

@@ -17,6 +17,13 @@
 static uint8_t *tcp_buff;
 extern luat_netdrv_napt_ctx_t *g_napt_tcp_ctx;
 
+void luat_netdrv_napt_tcp_cleanup(void) {
+    if (tcp_buff) {
+        luat_heap_free(tcp_buff);
+        tcp_buff = NULL;
+    }
+}
+
 #define u32 uint32_t
 #define u16 uint16_t
 #define u8 uint8_t

+ 7 - 0
components/network/netdrv/src/luat_netdrv_napt_udp.c

@@ -20,6 +20,13 @@ extern luat_netdrv_napt_ctx_t *g_napt_udp_ctx;
 #define NAPT_ETH_HDR_LEN sizeof(struct ethhdr)
 
 static uint8_t *udp_buff;
+
+void luat_netdrv_napt_udp_cleanup(void) {
+    if (udp_buff) {
+        luat_heap_free(udp_buff);
+        udp_buff = NULL;
+    }
+}
 __NETDRV_CODE_IN_RAM__ int luat_napt_udp_handle(napt_ctx_t *ctx)
 {
     uint16_t iphdr_len = (ctx->iphdr->_v_hl & 0x0F) * 4;