Просмотр исходного кода

update: netdrv,napt,优化checksum计算效率

Wendal Chen 3 месяцев назад
Родитель
Сommit
df5dd21c79

+ 39 - 14
components/network/netdrv/src/luat_netdrv_napt_icmp.c

@@ -26,6 +26,25 @@ static luat_netdrv_napt_icmp_t* icmps;
 #define u8 uint8_t
 #define NAPT_ETH_HDR_LEN             sizeof(struct ethhdr)
 
+// Incrementally update checksum when a 16-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u16(uint16_t sum_net, uint16_t old_net, uint16_t new_net)
+{
+    uint32_t acc = (~lwip_ntohs(sum_net) & 0xFFFFU) + (~lwip_ntohs(old_net) & 0xFFFFU) + lwip_ntohs(new_net);
+    acc = (acc >> 16) + (acc & 0xFFFFU);
+    acc += (acc >> 16);
+    return lwip_htons((uint16_t)(~acc));
+}
+
+// Incrementally update checksum when a 32-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u32(uint16_t sum_net, uint32_t old_net, uint32_t new_net)
+{
+    const uint16_t *old16 = (const uint16_t *)&old_net;
+    const uint16_t *new16 = (const uint16_t *)&new_net;
+    sum_net = napt_chksum_replace_u16(sum_net, old16[0], new16[0]);
+    sum_net = napt_chksum_replace_u16(sum_net, old16[1], new16[1]);
+    return sum_net;
+}
+
 
 static u16 luat_napt_icmp_id_alloc(void)
 {
@@ -85,16 +104,19 @@ int luat_napt_icmp_handle(napt_ctx_t* ctx) {
             }
             // 找到映射关系了!!!
             // LLOGD("ICMP id wnet %d inet %u", icmp_hdr->id, icmps[i].inet_id);
-            // 修改目标ID
+            // 修改目标ID并增量修正icmp checksum
+            uint16_t old_icmp_id = icmp_hdr->id;
+            uint16_t icmp_sum = icmp_hdr->chksum;
+            uint16_t ip_sum = ip_hdr->_chksum;
             icmp_hdr->id = icmps[i].inet_id;
-            // 重新计算icmp的checksum
-            icmp_hdr->chksum = 0;
-            icmp_hdr->chksum = alg_iphdr_chksum((u16 *)icmp_hdr, ntohs(ip_hdr->_len) - iphdr_len);
+            icmp_sum = napt_chksum_replace_u16(icmp_sum, old_icmp_id, icmp_hdr->id);
+            icmp_hdr->chksum = icmp_sum;
 
-            // 修改目标地址,并重新计算ip的checksum
+            // 修改目标地址,并增量更新ip的checksum
+            uint32_t old_dst_ip = ip_hdr->dest.addr;
             ip_hdr->dest.addr = icmps[i].inet_ip;
-            ip_hdr->_chksum = 0;
-            ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
+            ip_sum = napt_chksum_replace_u32(ip_sum, old_dst_ip, icmps[i].inet_ip);
+            IPH_CHKSUM_SET(ip_hdr, ip_sum);
 
             // 如果是ETH包, 那还需要修改源MAC和目标MAC
             if (ctx->eth) {
@@ -187,14 +209,17 @@ int luat_napt_icmp_handle(napt_ctx_t* ctx) {
             }
         }
         // 2. 修改信息
-        ip_hdr->src.addr = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t old_icmp_id = icmp_hdr->id;
+        uint32_t old_src_ip = ip_hdr->src.addr;
+        uint32_t new_src_ip = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t ip_sum = ip_hdr->_chksum;
+        uint16_t icmp_sum = icmp_hdr->chksum;
+        ip_hdr->src.addr = new_src_ip;
+        ip_sum = napt_chksum_replace_u32(ip_sum, old_src_ip, new_src_ip);
+        IPH_CHKSUM_SET(ip_hdr, ip_sum);
         icmp_hdr->id = it->wnet_id;
-        // 3. 重新计算checksum
-        icmp_hdr->chksum = 0;
-        icmp_hdr->chksum = alg_iphdr_chksum((u16 *)icmp_hdr, ntohs(ip_hdr->_len) - iphdr_len);
-        // 4. 计算IP包的checksum
-        ip_hdr->_chksum = 0;
-        ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
+        icmp_sum = napt_chksum_replace_u16(icmp_sum, old_icmp_id, icmp_hdr->id);
+        icmp_hdr->chksum = icmp_sum;
 
         // 5. 如果是ETH包, 还得修正MAC地址
         if (ctx->eth) {

+ 56 - 16
components/network/netdrv/src/luat_netdrv_napt_tcp.c

@@ -22,6 +22,25 @@ extern luat_netdrv_napt_ctx_t *g_napt_tcp_ctx;
 #define u8 uint8_t
 #define NAPT_ETH_HDR_LEN             sizeof(struct ethhdr)
 
+// Incrementally update checksum when a 16-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u16(uint16_t sum_net, uint16_t old_net, uint16_t new_net)
+{
+    uint32_t acc = (~lwip_ntohs(sum_net) & 0xFFFFU) + (~lwip_ntohs(old_net) & 0xFFFFU) + lwip_ntohs(new_net);
+    acc = (acc >> 16) + (acc & 0xFFFFU);
+    acc += (acc >> 16);
+    return lwip_htons((uint16_t)(~acc));
+}
+
+// Incrementally update checksum when a 32-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u32(uint16_t sum_net, uint32_t old_net, uint32_t new_net)
+{
+    const uint16_t *old16 = (const uint16_t *)&old_net;
+    const uint16_t *new16 = (const uint16_t *)&new_net;
+    sum_net = napt_chksum_replace_u16(sum_net, old16[0], new16[0]);
+    sum_net = napt_chksum_replace_u16(sum_net, old16[1], new16[1]);
+    return sum_net;
+}
+
 __NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
     uint16_t iphdr_len = (ctx->iphdr->_v_hl & 0x0F) * 4;
     struct ip_hdr* ip_hdr = ctx->iphdr;
@@ -44,22 +63,32 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
         ret = luat_netdrv_napt_tcp_wan2lan(ctx, &mapping, g_napt_tcp_ctx);
         if (ret == 0) {
             // 修改目标端口
+            uint16_t old_dst_port = tcp_hdr->dest;
+            uint32_t old_dst_ip = ip_hdr->dest.addr;
+            uint16_t ip_sum = ip_hdr->_chksum;
+            uint16_t tcp_sum = tcp_hdr->chksum;
             tcp_hdr->dest = mapping.inet_port;
 
-            // 修改目标地址到内网地址,并重新计算ip的checksu
+            // 修改目标地址到内网地址,并增量更新ip的checksum
             ip_hdr->dest.addr = mapping.inet_ip;
-            ip_hdr->_chksum = 0;
-            ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
+            ip_sum = napt_chksum_replace_u32(ip_sum, old_dst_ip, mapping.inet_ip);
+            IPH_CHKSUM_SET(ip_hdr, ip_sum);
 
             // 重新计算icmp的checksum
-            // if (tcp_hdr->chksum) {
-                tcp_hdr->chksum = 0;
+            if (tcp_sum)
+            {
+                tcp_sum = napt_chksum_replace_u32(tcp_sum, old_dst_ip, mapping.inet_ip);
+                tcp_sum = napt_chksum_replace_u16(tcp_sum, old_dst_port, mapping.inet_port);
+                tcp_hdr->chksum = tcp_sum;
+            }
+            else
+            {
                 tcp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
-                                               ip_hdr->dest.addr,
-                                               IP_PROTO_TCP,
-                                               (u16 *)tcp_hdr,
-                                               ntohs(ip_hdr->_len) - iphdr_len);
-            // }
+                                                       ip_hdr->dest.addr,
+                                                       IP_PROTO_TCP,
+                                                       (u16 *)tcp_hdr,
+                                                       ntohs(ip_hdr->_len) - iphdr_len);
+            }
 
             // 如果是ETH包, 那还需要修改源MAC和目标MAC
             if (ctx->eth) {
@@ -113,20 +142,31 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
         it_map = &mapping;
 
         // 2. 修改信息
-        ip_hdr->src.addr = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t old_src_port = tcp_hdr->src;
+        uint32_t old_src_ip = ip_hdr->src.addr;
+        uint32_t new_src_ip = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t ip_sum = ip_hdr->_chksum;
+        uint16_t tcp_sum = tcp_hdr->chksum;
+        ip_hdr->src.addr = new_src_ip;
+        ip_sum = napt_chksum_replace_u32(ip_sum, old_src_ip, new_src_ip);
+        IPH_CHKSUM_SET(ip_hdr, ip_sum);
         tcp_hdr->src = it_map->wnet_local_port;
         // 3. 与ICMP不同, 先计算IP的checksum
-        ip_hdr->_chksum = 0;
-        ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
         // 4. 计算IP包的checksum
-        // if (tcp_hdr->chksum) {
-            tcp_hdr->chksum = 0;
+        if (tcp_sum)
+        {
+            tcp_sum = napt_chksum_replace_u32(tcp_sum, old_src_ip, new_src_ip);
+            tcp_sum = napt_chksum_replace_u16(tcp_sum, old_src_port, it_map->wnet_local_port);
+            tcp_hdr->chksum = tcp_sum;
+        }
+        else
+        {
             tcp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
                                                    ip_hdr->dest.addr,
                                                    IP_PROTO_TCP,
                                                    (u16 *)tcp_hdr,
                                                    ntohs(ip_hdr->_len) - iphdr_len);
-        // }
+        }
 
         // 发送出去
         if (gw && gw->dataout && gw->netif) {

+ 42 - 9
components/network/netdrv/src/luat_netdrv_napt_udp.c

@@ -20,6 +20,25 @@ extern luat_netdrv_napt_ctx_t *g_napt_udp_ctx;
 #define NAPT_ETH_HDR_LEN sizeof(struct ethhdr)
 
 static uint8_t *udp_buff;
+// Incrementally update checksum when a 16-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u16(uint16_t sum_net, uint16_t old_net, uint16_t new_net)
+{
+    uint32_t acc = (~lwip_ntohs(sum_net) & 0xFFFFU) + (~lwip_ntohs(old_net) & 0xFFFFU) + lwip_ntohs(new_net);
+    acc = (acc >> 16) + (acc & 0xFFFFU);
+    acc += (acc >> 16);
+    return lwip_htons((uint16_t)(~acc));
+}
+
+// Incrementally update checksum when a 32-bit field (network order) changes.
+static inline uint16_t napt_chksum_replace_u32(uint16_t sum_net, uint32_t old_net, uint32_t new_net)
+{
+    const uint16_t *old16 = (const uint16_t *)&old_net;
+    const uint16_t *new16 = (const uint16_t *)&new_net;
+    sum_net = napt_chksum_replace_u16(sum_net, old16[0], new16[0]);
+    sum_net = napt_chksum_replace_u16(sum_net, old16[1], new16[1]);
+    return sum_net;
+}
+
 __NETDRV_CODE_IN_RAM__ int luat_napt_udp_handle(napt_ctx_t *ctx)
 {
     uint16_t iphdr_len = (ctx->iphdr->_v_hl & 0x0F) * 4;
@@ -40,17 +59,24 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_udp_handle(napt_ctx_t *ctx)
         if (ret == 0)
         {
             // 找到映射关系了!!! 修改目标ID
+            uint16_t old_dst_port = udp_hdr->dest;
+            uint32_t old_dst_ip = ip_hdr->dest.addr;
             udp_hdr->dest = mapping.inet_port;
 
-            // 修改目标地址到内网地址,并重新计算ip的checksum
+            // 修改目标地址到内网地址,并增量更新ip的checksum
+            uint16_t ip_sum = ip_hdr->_chksum;
             ip_hdr->dest.addr = mapping.inet_ip;
-            IPH_CHKSUM_SET(ip_hdr, 0);
-            IPH_CHKSUM_SET(ip_hdr, alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len));
+            ip_sum = napt_chksum_replace_u32(ip_sum, old_dst_ip, mapping.inet_ip);
+            IPH_CHKSUM_SET(ip_hdr, ip_sum);
 
             // 重新计算icmp的checksum
             if (udp_hdr->chksum)
             {
-                udp_hdr->chksum = 0;
+                udp_hdr->chksum = napt_chksum_replace_u32(udp_hdr->chksum, old_dst_ip, mapping.inet_ip);
+                udp_hdr->chksum = napt_chksum_replace_u16(udp_hdr->chksum, old_dst_port, mapping.inet_port);
+            }
+            else
+            {
                 udp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
                                                        ip_hdr->dest.addr,
                                                        IP_PROTO_UDP,
@@ -115,21 +141,28 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_udp_handle(napt_ctx_t *ctx)
             return 0;
         }
         // 2. 修改信息
-        ip_hdr->src.addr = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t old_src_port = udp_hdr->src;
+        uint32_t old_src_ip = ip_hdr->src.addr;
+        uint32_t new_src_ip = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
+        uint16_t ip_sum = ip_hdr->_chksum;
+        ip_hdr->src.addr = new_src_ip;
+        ip_sum = napt_chksum_replace_u32(ip_sum, old_src_ip, new_src_ip);
+        IPH_CHKSUM_SET(ip_hdr, ip_sum);
         udp_hdr->src = mapping.wnet_local_port;
         // 3. 与ICMP不同, 先计算IP的checksum
-        IPH_CHKSUM_SET(ip_hdr, 0);
-        IPH_CHKSUM_SET(ip_hdr, alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len));
         // 4. 计算IP包的checksum
         if (udp_hdr->chksum)
         {
-            udp_hdr->chksum = 0;
+            udp_hdr->chksum = napt_chksum_replace_u32(udp_hdr->chksum, old_src_ip, new_src_ip);
+            udp_hdr->chksum = napt_chksum_replace_u16(udp_hdr->chksum, old_src_port, mapping.wnet_local_port);
+        }
+        else
+        {
             udp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
                                                    ip_hdr->dest.addr,
                                                    IP_PROTO_UDP,
                                                    (u16 *)udp_hdr,
                                                    ntohs(ip_hdr->_len) - iphdr_len);
-            // udp_hdr->chksum = 0; // 强制不校验
         }
 
         // 发送出去