Bladeren bron

fix: 修改fft库位置,并增加fft内存标记,用于提高fft速度

zengshuai 7 maanden geleden
bovenliggende
commit
4ce9a5c078
3 gewijzigde bestanden met toevoegingen van 340 en 0 verwijderingen
  1. 151 0
      components/fft/binding/luat_lib_fft.c
  2. 31 0
      components/fft/inc/fft_core.h
  3. 158 0
      components/fft/src/fft_core.c

+ 151 - 0
components/fft/binding/luat_lib_fft.c

@@ -0,0 +1,151 @@
+#include "luat_base.h"
+#include "luat_mem.h"
+
+#include <math.h>
+#include <string.h>
+
+#include "fft_core.h"
+
+#define LUAT_LOG_TAG "fft"
+#include "luat_log.h"
+
+#include "luat_conf_bsp.h"
+#include "rotable2.h"
+#include "luat_zbuff.h"
+
+// Copy entries 1..n of the Lua value at stack index idx into out[0..n-1].
+// idx is read with lua_rawgeti, so it is expected to refer to a table.
+// Returns 0 when every entry passes lua_isnumber, or -1 at the first entry
+// that does not; the Lua stack is left balanced in both cases.
+static int read_lua_array_float(lua_State* L, int idx, float* out, int n) {
+    int slot = 0;
+    while (slot < n) {
+        lua_rawgeti(L, idx, slot + 1);      // Lua keys are 1-based
+        int is_num = lua_isnumber(L, -1);
+        if (is_num) out[slot] = (float)lua_tonumber(L, -1);
+        lua_pop(L, 1);
+        if (!is_num) return -1;
+        slot++;
+    }
+    return 0;
+}
+
+// Create a new Lua table, store a[0..n-1] under keys 1..n, and leave the
+// table on top of the stack.
+static void write_lua_array_float(lua_State* L, float* a, int n) {
+    lua_createtable(L, n, 0);
+    for (int key = 1; key <= n; key++) {
+        lua_pushnumber(L, a[key - 1]);
+        lua_rawseti(L, -2, key);            // pops the value just pushed
+    }
+}
+
+// Lua binding registered as "generate_twiddles": takes N (argument 1) and
+// returns two new tables of N/2 numbers each, filled by fft_generate_twiddles.
+// Raises a Lua error when N is not a power of two greater than 1 or when the
+// temporary buffers cannot be allocated.
+static int l_fft_generate_twiddles(lua_State* L) {
+    int N = luaL_checkinteger(L, 1);
+    int is_pow2 = (N > 1) && ((N & (N-1)) == 0);
+    if (!is_pow2) return luaL_error(L, "N must be power of 2");
+
+    int count = N/2;
+    float* cos_tab = luat_heap_malloc(sizeof(float)*count);
+    float* sin_tab = luat_heap_malloc(sizeof(float)*count);
+    if (cos_tab && sin_tab) {
+        fft_generate_twiddles(N, cos_tab, sin_tab);
+        // Push order defines the return order: cosine table first, sine second.
+        write_lua_array_float(L, cos_tab, count);
+        write_lua_array_float(L, sin_tab, count);
+        luat_heap_free(cos_tab);
+        luat_heap_free(sin_tab);
+        return 2;
+    }
+
+    // At least one allocation failed: release whichever one succeeded.
+    if (cos_tab) luat_heap_free(cos_tab);
+    if (sin_tab) luat_heap_free(sin_tab);
+    return luaL_error(L, "no memory");
+}
+
+// One float-array argument of a Lua call, resolved to a C buffer.
+typedef struct {
+    float* data;   // element storage; NULL when an optional argument was omitted
+    int owned;     // non-zero: data was allocated here from a Lua table and must be freed
+} l_fft_arg_t;
+
+// Free the buffer if this module allocated it; borrowed zbuff storage is left alone.
+static void l_fft_arg_release(l_fft_arg_t* arg) {
+    if (arg->owned && arg->data) luat_heap_free(arg->data);
+    arg->data = NULL;
+    arg->owned = 0;
+}
+
+// Resolve the argument at stack index idx into `count` floats.
+//  - zbuff: its storage is used directly (no copy), after checking that it
+//    can hold `count` floats
+//  - table: entries 1..count are copied into a freshly allocated buffer
+//  - nil / absent: accepted only when `optional` is set; out->data stays NULL
+// Returns NULL on success or a static error message. On failure nothing is
+// left allocated and *out is cleared.
+static const char* l_fft_arg_load(lua_State* L, int idx, int count, int optional,
+                                  const char* type_msg, l_fft_arg_t* out) {
+    out->data = NULL;
+    out->owned = 0;
+
+    luat_zbuff_t* zb = (luat_zbuff_t*)luaL_testudata(L, idx, LUAT_ZBUFF_TYPE);
+    if (zb) {
+        // The transform reads and writes `count` floats through this pointer,
+        // so a shorter zbuff would be overrun. Dividing (rather than
+        // multiplying count) keeps the comparison free of size_t wrap-around.
+        if (!zb->addr || zb->len / sizeof(float) < (size_t)count) return "zbuff too small";
+        out->data = (float*)zb->addr;
+        return NULL;
+    }
+    if (optional && lua_isnoneornil(L, idx)) return NULL;
+    // lua_rawgeti is only valid on tables, so reject everything else up front.
+    if (!lua_istable(L, idx)) return type_msg;
+    // Guard the byte-size computation on targets with a 32-bit size_t.
+    if ((size_t)count > ((size_t)-1) / sizeof(float)) return "no memory";
+
+    float* buf = luat_heap_malloc(sizeof(float)*(size_t)count);
+    if (!buf) return "no memory";
+    if (read_lua_array_float(L, idx, buf, count)) {
+        luat_heap_free(buf);
+        return type_msg;
+    }
+    out->data = buf;
+    out->owned = 1;
+    return NULL;
+}
+
+// Copy `count` floats back into the Lua table at absolute stack index idx
+// (keys 1..count), so table callers see the result in the table they passed.
+static void l_fft_arg_store(lua_State* L, int idx, const float* src, int count) {
+    for (int i = 0; i < count; i++) {
+        lua_pushnumber(L, src[i]);
+        lua_rawseti(L, idx, i+1);
+    }
+}
+
+// Shared body of the "run" and "ifft" bindings.
+// Lua arguments: (real, imag, N, W_real, W_imag)
+//  - real, imag: zbuff or number table with N elements; overwritten with the result
+//  - N: power of two greater than 1
+//  - W_real, W_imag: zbuff or number table with N/2 elements; nil/absent makes
+//    the core compute twiddles on the fly (it needs both tables to use them)
+// Returns nothing to Lua; raises a Lua error on invalid arguments or when a
+// temporary buffer cannot be allocated.
+static int l_fft_transform(lua_State* L, int inverse) {
+    int N = luaL_checkinteger(L, 3);
+    if (N <= 1 || (N & (N-1))) return luaL_error(L, "N must be power of 2");
+
+    l_fft_arg_t re = {NULL, 0}, im = {NULL, 0}, wc = {NULL, 0}, ws = {NULL, 0};
+    const char* err = l_fft_arg_load(L, 1, N, 0, "real must be number array or zbuff", &re);
+    if (!err) err = l_fft_arg_load(L, 2, N, 0, "imag must be number array or zbuff", &im);
+    if (!err) err = l_fft_arg_load(L, 4, N/2, 1, "W_real must be number array or zbuff", &wc);
+    if (!err) err = l_fft_arg_load(L, 5, N/2, 1, "W_imag must be number array or zbuff", &ws);
+
+    if (!err) {
+        if (inverse) ifft_run_inplace(re.data, im.data, N, wc.data, ws.data);
+        else         fft_run_inplace(re.data, im.data, N, wc.data, ws.data);
+        // zbuff inputs were transformed in place; table inputs need a copy back.
+        if (re.owned) l_fft_arg_store(L, 1, re.data, N);
+        if (im.owned) l_fft_arg_store(L, 2, im.data, N);
+    }
+
+    // Release before raising: luaL_error does not return, so nothing after it runs.
+    l_fft_arg_release(&re);
+    l_fft_arg_release(&im);
+    l_fft_arg_release(&wc);
+    l_fft_arg_release(&ws);
+    if (err) return luaL_error(L, "%s", err);
+    return 0;
+}
+
+// Lua binding "run": forward FFT, see l_fft_transform for the argument list.
+static int l_fft_run(lua_State* L) {
+    return l_fft_transform(L, 0);
+}
+
+// Lua binding "ifft": inverse FFT with 1/N scaling, see l_fft_transform.
+static int l_fft_ifft(lua_State* L) {
+    return l_fft_transform(L, 1);
+}
+
+// Lua binding "fft_integral": frequency-domain integration of a spectrum.
+// Lua arguments: (real, imag, n, df)
+//  - real, imag: zbuff or number table with n elements; overwritten with the result
+//  - n: power of two greater than 1
+//  - df: number passed to fft_integral_inplace as the frequency resolution
+// Returns nothing to Lua; raises a Lua error on invalid arguments or when a
+// temporary buffer cannot be allocated.
+static int l_fft_integral(lua_State* L) {
+    int n = luaL_checkinteger(L, 3);
+    float df = (float)luaL_checknumber(L, 4);
+    if (n <= 1 || (n & (n-1))) return luaL_error(L, "n must be power of 2");
+
+    l_fft_arg_t re = {NULL, 0}, im = {NULL, 0};
+    const char* err = l_fft_arg_load(L, 1, n, 0, "real must be number array or zbuff", &re);
+    if (!err) err = l_fft_arg_load(L, 2, n, 0, "imag must be number array or zbuff", &im);
+
+    if (!err) {
+        fft_integral_inplace(re.data, im.data, n, df);
+        if (re.owned) l_fft_arg_store(L, 1, re.data, n);
+        if (im.owned) l_fft_arg_store(L, 2, im.data, n);
+    }
+
+    l_fft_arg_release(&re);
+    l_fft_arg_release(&im);
+    if (err) return luaL_error(L, "%s", err);
+    return 0;
+}
+
+static const rotable_Reg_t reg_fft[] = {                         // name -> C function table handed to luat_newlib2 in luaopen_fft
+    {"generate_twiddles", ROREG_FUNC(l_fft_generate_twiddles)},  // (N) -> cosine table, sine table
+    {"run",               ROREG_FUNC(l_fft_run)},                // forward FFT on (real, imag, N, W_real, W_imag)
+    {"ifft",              ROREG_FUNC(l_fft_ifft)},               // inverse FFT, same argument list as "run"
+    {"fft_integral",      ROREG_FUNC(l_fft_integral)},           // frequency-domain integral on (real, imag, n, df)
+    {NULL,                 ROREG_INT(0)}                          // NULL-named last entry; presumably the end-of-table marker — confirm in rotable2.h
+};
+
+LUAMOD_API int luaopen_fft(lua_State *L) {   // module opener for the fft library
+    luat_newlib2(L, reg_fft);                // presumably builds the module table from reg_fft and leaves it on the stack — confirm in luat_newlib2
+    return 1;                                // one value is returned to Lua
+}
+
+

+ 31 - 0
components/fft/inc/fft_core.h

@@ -0,0 +1,31 @@
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Generate twiddle factors for size N (N must be power of 2): Wc[k] = cos(-2*pi*k/N), Ws[k] = sin(-2*pi*k/N), k = 0 .. N/2-1
+// Wc and Ws must have length at least N/2 (float32); both arrays are written by this call
+void fft_generate_twiddles(int N, float* Wc, float* Ws);
+
+// In-place iterative radix-2 FFT of the N-point complex sequence real[] + j*imag[]; the result overwrites both arrays
+// Wc/Ws are tables laid out as by fft_generate_twiddles; if either is NULL, twiddles are computed on the fly
+void fft_run_inplace(float* real, float* imag, int N,
+                     const float* Wc, const float* Ws);
+
+// In-place IFFT with 1/N normalization; arguments have the same meaning as for fft_run_inplace
+void ifft_run_inplace(float* real, float* imag, int N,
+                      const float* Wc, const float* Ws);
+
+// Frequency-domain integral: X(ω) -> X(ω)/(jω), applied in place to xr[] + j*xi[]; the DC bin (index 0) is set to 0
+// n is FFT size, df is frequency resolution fs/n (bin k is treated as ω = 2*pi*k*df)
+void fft_integral_inplace(float* xr, float* xi, int n, float df);
+
+#ifdef __cplusplus
+}
+#endif
+
+

+ 158 - 0
components/fft/src/fft_core.c

@@ -0,0 +1,158 @@
+#include "fft_core.h"
+#include "luat_base.h"
+#include <math.h>
+#include <string.h>
+#include <stdint.h>
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+#ifndef __USER_FUNC_IN_RAM__
+#define __USER_FUNC_IN_RAM__ 
+#endif
+
+// Return the low `bits` bits of x in reversed order (bit 0 of x ends up as
+// bit bits-1 of the result). Yields 0 when bits <= 0.
+__USER_FUNC_IN_RAM__ static inline int reverse_bits(int x, int bits) {
+    int reversed = 0;
+    while (bits-- > 0) {
+        reversed <<= 1;        // make room for the next bit
+        reversed |= x & 1;     // append the current lowest bit of x
+        x >>= 1;
+    }
+    return reversed;
+}
+
+// Fill the twiddle tables for an N-point transform:
+//   Wc[k] = cos(-2*pi*k/N), Ws[k] = sin(-2*pi*k/N)   for k = 0 .. N/2-1
+// Wc and Ws must each have room for N/2 floats. Nothing is written when N < 2.
+__USER_FUNC_IN_RAM__ void fft_generate_twiddles(int N, float* Wc, float* Ws) {
+    for (int k = 0; k < N/2; k++) {
+        // Keep the angle in double: cos()/sin() are evaluated in double anyway,
+        // and rounding the angle to float first would cost accuracy in the
+        // stored table for no gain.
+        double angle = -2.0 * M_PI * (double)k / (double)N;
+        Wc[k] = (float)cos(angle);
+        Ws[k] = (float)sin(angle);
+    }
+}
+
+// In-place iterative radix-2 (Cooley-Tukey) transform of real[] + j*imag[].
+//   N        number of points; callers pass a power of two
+//   inverse  non-zero conjugates the twiddles (inverse transform, unscaled)
+//   Wc, Ws   tables with Wc[k] = cos(-2*pi*k/N), Ws[k] = sin(-2*pi*k/N) for
+//            k < N/2; when either is NULL the factors are computed on the fly
+__USER_FUNC_IN_RAM__ static void fft_inplace_core(float* real, float* imag, int N, int inverse,
+                             const float* Wc, const float* Ws) {
+    // Number of bits needed to index all N elements.
+    int bits = 0; while ((1 << bits) < N) bits++;
+
+    // Bit-reversal permutation: move every element to its bit-reversed index.
+    for (int i = 0; i < N; i++) {
+        int j = reverse_bits(i, bits);
+        if (j > i) {  // swap each pair once only
+            float tr = real[i]; real[i] = real[j]; real[j] = tr;
+            float ti = imag[i]; imag[i] = imag[j]; imag[j] = ti;
+        }
+    }
+
+    const int use_table = (Wc && Ws);
+
+    // Butterfly stages: sub-transform length doubles from 2 up to N.
+    // Iterating over the stage number (instead of shifting len until it
+    // exceeds N) keeps the loop variable from overflowing after the last stage.
+    for (int stage = 1; stage <= bits; stage++) {
+        int len = 1 << stage;
+        int half = len >> 1;   // butterflies per sub-transform
+        int step = N / len;    // stride through the N/2-entry twiddle table
+
+        // All butterflies of a stage touch disjoint element pairs, so they can
+        // be visited in any order. Going twiddle-by-twiddle lets each factor be
+        // fetched (or computed with cos/sin) once per stage instead of once
+        // per sub-transform.
+        for (int j = 0; j < half; j++) {
+            int idx = j * step;
+            float wr, wi;
+            if (use_table) {
+                wr = Wc[idx];
+                wi = Ws[idx];
+            } else {
+                double angle = -2.0 * M_PI * (double)idx / (double)N;
+                wr = (float)cos(angle);
+                wi = (float)sin(angle);
+            }
+            // The inverse transform uses the conjugate twiddle.
+            if (inverse) wi = -wi;
+
+            for (int p = j; p < N; p += len) {
+                int q = p + half;   // partner element in the lower half
+
+                // t = W * X[q]  (complex multiply)
+                float tr = wr * real[q] - wi * imag[q];
+                float ti = wr * imag[q] + wi * real[q];
+
+                float ur = real[p];
+                float ui = imag[p];
+
+                // X[p] = u + t,  X[q] = u - t
+                real[p] = ur + tr;
+                imag[p] = ui + ti;
+                real[q] = ur - tr;
+                imag[q] = ui - ti;
+            }
+        }
+    }
+}
+
+// Forward transform: runs the shared core with the inverse flag cleared.
+// real/imag hold the N-point input and receive the output; Wc/Ws are passed
+// through unchanged.
+__USER_FUNC_IN_RAM__ void fft_run_inplace(float* real, float* imag, int N,
+                     const float* Wc, const float* Ws) {
+    const int inverse = 0;
+    fft_inplace_core(real, imag, N, inverse, Wc, Ws);
+}
+
+// Inverse transform: runs the shared core with the inverse flag set, then
+// multiplies every real and imaginary sample by 1/N.
+__USER_FUNC_IN_RAM__ void ifft_run_inplace(float* real, float* imag, int N,
+                      const float* Wc, const float* Ws) {
+    fft_inplace_core(real, imag, N, 1, Wc, Ws);
+
+    // The reciprocal is formed once in double and rounded to float.
+    const float scale = (float)(1.0 / (double)N);
+    for (int k = 0; k < N; k++) {
+        real[k] *= scale;
+    }
+    for (int k = 0; k < N; k++) {
+        imag[k] *= scale;
+    }
+}
+
+// Frequency-domain integration, in place: X(w) -> X(w) / (j*w).
+//   xr, xi  real and imaginary parts of an n-point spectrum in FFT bin order
+//   n       number of bins
+//   df      frequency spacing between bins (bin k is treated as w = 2*pi*k*df)
+// Bins below n/2 use k = i, bins above n/2 use the negative index k = i - n.
+// The DC bin is set to zero. For even n the Nyquist bin n/2 is also set to
+// zero: it has no matching negative-frequency bin, so it cannot be divided by
+// j*w consistently, and leaving it unscaled would mix one un-integrated
+// component into an otherwise integrated spectrum.
+__USER_FUNC_IN_RAM__ void fft_integral_inplace(float* xr, float* xi, int n, float df) {
+    if (n <= 0) return;   // no bins at all, so bin 0 must not be touched
+
+    // Angular-frequency spacing between adjacent bins.
+    const float two_pi_df = (float)(2.0 * M_PI) * df;
+
+    // DC: w = 0, the integral is defined as zero here.
+    xr[0] = 0.0f; xi[0] = 0.0f;
+
+    for (int i = 1; i < n; i++) {
+        int mirror = n - i;             // index of the bin at the opposite frequency
+        if (i == mirror) {              // Nyquist bin (only exists for even n)
+            xr[i] = 0.0f; xi[i] = 0.0f;
+            continue;
+        }
+        int k = (i < mirror) ? i : -mirror;   // signed bin index
+        float omega = two_pi_df * (float)k;
+        if (omega != 0.0f) {
+            float xr0 = xr[i];
+            float xi0 = xi[i];
+            // (xr + j*xi) / (j*w) = xi/w - j*xr/w
+            xr[i] =  xi0 / omega;
+            xi[i] = -xr0 / omega;
+        } else {
+            // w collapsed to zero (e.g. df == 0): avoid dividing by zero.
+            xr[i] = 0.0f; xi[i] = 0.0f;
+        }
+    }
+}
+
+