|
|
@@ -22,6 +22,25 @@ extern luat_netdrv_napt_ctx_t *g_napt_tcp_ctx;
|
|
|
#define u8 uint8_t
|
|
|
#define NAPT_ETH_HDR_LEN sizeof(struct ethhdr)
|
|
|
|
|
|
+// Incrementally update checksum when a 16-bit field (network order) changes.
|
|
|
+static inline uint16_t napt_chksum_replace_u16(uint16_t sum_net, uint16_t old_net, uint16_t new_net)
|
|
|
+{
|
|
|
+ uint32_t acc = (~lwip_ntohs(sum_net) & 0xFFFFU) + (~lwip_ntohs(old_net) & 0xFFFFU) + lwip_ntohs(new_net);
|
|
|
+ acc = (acc >> 16) + (acc & 0xFFFFU);
|
|
|
+ acc += (acc >> 16);
|
|
|
+ return lwip_htons((uint16_t)(~acc));
|
|
|
+}
|
|
|
+
|
|
|
+// Incrementally update checksum when a 32-bit field (network order) changes.
|
|
|
+static inline uint16_t napt_chksum_replace_u32(uint16_t sum_net, uint32_t old_net, uint32_t new_net)
|
|
|
+{
|
|
|
+ const uint16_t *old16 = (const uint16_t *)&old_net;
|
|
|
+ const uint16_t *new16 = (const uint16_t *)&new_net;
|
|
|
+ sum_net = napt_chksum_replace_u16(sum_net, old16[0], new16[0]);
|
|
|
+ sum_net = napt_chksum_replace_u16(sum_net, old16[1], new16[1]);
|
|
|
+ return sum_net;
|
|
|
+}
|
|
|
+
|
|
|
__NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
|
|
|
uint16_t iphdr_len = (ctx->iphdr->_v_hl & 0x0F) * 4;
|
|
|
struct ip_hdr* ip_hdr = ctx->iphdr;
|
|
|
@@ -44,22 +63,32 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
|
|
|
ret = luat_netdrv_napt_tcp_wan2lan(ctx, &mapping, g_napt_tcp_ctx);
|
|
|
if (ret == 0) {
|
|
|
// 修改目标端口
|
|
|
+ uint16_t old_dst_port = tcp_hdr->dest;
|
|
|
+ uint32_t old_dst_ip = ip_hdr->dest.addr;
|
|
|
+ uint16_t ip_sum = ip_hdr->_chksum;
|
|
|
+ uint16_t tcp_sum = tcp_hdr->chksum;
|
|
|
tcp_hdr->dest = mapping.inet_port;
|
|
|
|
|
|
- // 修改目标地址到内网地址,并重新计算ip的checksu
|
|
|
+ // 修改目标地址到内网地址,并增量更新ip的checksum
|
|
|
ip_hdr->dest.addr = mapping.inet_ip;
|
|
|
- ip_hdr->_chksum = 0;
|
|
|
- ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
|
|
|
+ ip_sum = napt_chksum_replace_u32(ip_sum, old_dst_ip, mapping.inet_ip);
|
|
|
+ IPH_CHKSUM_SET(ip_hdr, ip_sum);
|
|
|
|
|
|
// 重新计算icmp的checksum
|
|
|
- // if (tcp_hdr->chksum) {
|
|
|
- tcp_hdr->chksum = 0;
|
|
|
+ if (tcp_sum)
|
|
|
+ {
|
|
|
+ tcp_sum = napt_chksum_replace_u32(tcp_sum, old_dst_ip, mapping.inet_ip);
|
|
|
+ tcp_sum = napt_chksum_replace_u16(tcp_sum, old_dst_port, mapping.inet_port);
|
|
|
+ tcp_hdr->chksum = tcp_sum;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
tcp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
|
|
|
- ip_hdr->dest.addr,
|
|
|
- IP_PROTO_TCP,
|
|
|
- (u16 *)tcp_hdr,
|
|
|
- ntohs(ip_hdr->_len) - iphdr_len);
|
|
|
- // }
|
|
|
+ ip_hdr->dest.addr,
|
|
|
+ IP_PROTO_TCP,
|
|
|
+ (u16 *)tcp_hdr,
|
|
|
+ ntohs(ip_hdr->_len) - iphdr_len);
|
|
|
+ }
|
|
|
|
|
|
// 如果是ETH包, 那还需要修改源MAC和目标MAC
|
|
|
if (ctx->eth) {
|
|
|
@@ -113,20 +142,31 @@ __NETDRV_CODE_IN_RAM__ int luat_napt_tcp_handle(napt_ctx_t* ctx) {
|
|
|
it_map = &mapping;
|
|
|
|
|
|
// 2. 修改信息
|
|
|
- ip_hdr->src.addr = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
|
|
|
+ uint16_t old_src_port = tcp_hdr->src;
|
|
|
+ uint32_t old_src_ip = ip_hdr->src.addr;
|
|
|
+ uint32_t new_src_ip = ip_addr_get_ip4_u32(&gw->netif->ip_addr);
|
|
|
+ uint16_t ip_sum = ip_hdr->_chksum;
|
|
|
+ uint16_t tcp_sum = tcp_hdr->chksum;
|
|
|
+ ip_hdr->src.addr = new_src_ip;
|
|
|
+ ip_sum = napt_chksum_replace_u32(ip_sum, old_src_ip, new_src_ip);
|
|
|
+ IPH_CHKSUM_SET(ip_hdr, ip_sum);
|
|
|
tcp_hdr->src = it_map->wnet_local_port;
|
|
|
// 3. 与ICMP不同, 先计算IP的checksum
|
|
|
- ip_hdr->_chksum = 0;
|
|
|
- ip_hdr->_chksum = alg_iphdr_chksum((u16 *)ip_hdr, iphdr_len);
|
|
|
// 4. 计算IP包的checksum
|
|
|
- // if (tcp_hdr->chksum) {
|
|
|
- tcp_hdr->chksum = 0;
|
|
|
+ if (tcp_sum)
|
|
|
+ {
|
|
|
+ tcp_sum = napt_chksum_replace_u32(tcp_sum, old_src_ip, new_src_ip);
|
|
|
+ tcp_sum = napt_chksum_replace_u16(tcp_sum, old_src_port, it_map->wnet_local_port);
|
|
|
+ tcp_hdr->chksum = tcp_sum;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
tcp_hdr->chksum = alg_tcpudphdr_chksum(ip_hdr->src.addr,
|
|
|
ip_hdr->dest.addr,
|
|
|
IP_PROTO_TCP,
|
|
|
(u16 *)tcp_hdr,
|
|
|
ntohs(ip_hdr->_len) - iphdr_len);
|
|
|
- // }
|
|
|
+ }
|
|
|
|
|
|
// 发送出去
|
|
|
if (gw && gw->dataout && gw->netif) {
|