luat_lib_fft.c 28 KB


  /*
  @module fft
  @summary Fast Fourier transform (FFT/IFFT) with float32 and q15 fixed-point kernels
  @version 0.2
  @date 2025.08
  @demo fft
  @tag LUAT_USE_FFT
  @usage
  -- The module is normally built in: call it directly, no require needed
  -- See the per-API comments and the demo/test scripts in the repository for examples
  */
  12. #include "luat_base.h"
  13. #include "luat_mem.h"
  14. #include <math.h>
  15. #include <string.h>
  16. #include "fft_core.h"
  17. #define LUAT_LOG_TAG "fft"
  18. #include "luat_log.h"
  19. #include "luat_conf_bsp.h"
  20. #include "luat_zbuff.h"
  21. #include "rotable2.h"
  22. // Q15 内核头文件(当前版本已实现)
  23. #include "fft_core_q15.h"
  24. // helper: read float array from lua table (1-based)
  25. static int read_lua_array_float(lua_State* L, int idx, float* out, int n)
  26. {
  27. for (int i = 0; i < n; i++) {
  28. lua_rawgeti(L, idx, i + 1);
  29. if (!lua_isnumber(L, -1)) {
  30. lua_pop(L, 1);
  31. return -1;
  32. }
  33. out[i] = (float)lua_tonumber(L, -1);
  34. lua_pop(L, 1);
  35. }
  36. return 0;
  37. }
  38. // helper: write float array into lua table (1-based)
  39. static void write_lua_array_float(lua_State* L, float* a, int n)
  40. {
  41. lua_createtable(L, n, 0);
  42. for (int i = 0; i < n; i++) {
  43. lua_pushnumber(L, a[i]);
  44. lua_rawseti(L, -2, i + 1);
  45. }
  46. }
  47. /*
  48. 生成 float32 旋转因子
  49. @api fft.generate_twiddles(N)
  50. @int N 点数,必须为 2 的幂
  51. @return table Wc, table Ws 两个 Lua 数组(长度 N/2),分别为 cos 与 -sin
  52. @usage
  53. local N = 2048
  54. local Wc, Ws = fft.generate_twiddles(N)
  55. */
  56. static int l_fft_generate_twiddles(lua_State* L)
  57. {
  58. int N = luaL_checkinteger(L, 1);
  59. if (N <= 1 || (N & (N - 1)))
  60. return luaL_error(L, "N must be power of 2");
  61. int half = N / 2;
  62. float* Wc = luat_heap_malloc(sizeof(float) * half);
  63. float* Ws = luat_heap_malloc(sizeof(float) * half);
  64. if (!Wc || !Ws) {
  65. if (Wc)
  66. luat_heap_free(Wc);
  67. if (Ws)
  68. luat_heap_free(Ws);
  69. return luaL_error(L, "no memory");
  70. }
  71. luat_fft_generate_twiddles(N, Wc, Ws);
  72. write_lua_array_float(L, Wc, half);
  73. write_lua_array_float(L, Ws, half);
  74. luat_heap_free(Wc);
  75. luat_heap_free(Ws);
  76. return 2;
  77. }
  78. // 生成 Q15 旋转因子到 zbuff(长度要求:N/2 * 2 字节)
  79. // 调用方式:fft.generate_twiddles_q15_to_zbuff(N, Wc_zbuff, Ws_zbuff)
  80. /*
  81. 生成 q15 旋转因子到 zbuff(零浮点)
  82. @api fft.generate_twiddles_q15_to_zbuff(N, Wc_zb, Ws_zb)
  83. @int N 点数,必须为 2 的幂
  84. @zbuff Wc_zb 输出缓冲,长度至少为 (N/2)*2 字节,存放 int16 Q15 的 cos
  85. @zbuff Ws_zb 输出缓冲,长度至少为 (N/2)*2 字节,存放 int16 Q15 的 -sin(前向)
  86. @return nil 无返回值,结果写入传入的 zbuff
  87. @usage
  88. local N = 2048
  89. local Wc_q15 = zbuff.create((N//2)*2)
  90. local Ws_q15 = zbuff.create((N//2)*2)
  91. fft.generate_twiddles_q15_to_zbuff(N, Wc_q15, Ws_q15)
  92. */
  93. static int l_fft_generate_twiddles_q15_to_zbuff(lua_State* L)
  94. {
  95. // 使用整型查表(度数,1度分辨率,缩放256)生成近似Q15旋转因子
  96. static const int16_t SIN_TABLE256[91] = {
  97. 0, 4, 9, 13, 18, 22, 27, 31, 36, 40, 44, 49, 53, 58, 62, 66, 71, 75, 79, 83,
  98. 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 139, 143, 147, 150, 154, 158, 161,
  99. 165, 168, 171, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 204, 207, 210, 212, 215, 217, 219,
  100. 222, 224, 226, 228, 230, 232, 234, 236, 237, 239, 241, 242, 243, 245, 246, 247, 248, 249, 250, 251,
  101. 252, 253, 254, 254, 255, 255, 255, 256, 256, 256, 256
  102. };
  103. int N = luaL_checkinteger(L, 1);
  104. if (N <= 1 || (N & (N - 1)))
  105. return luaL_error(L, "N 必须为 2 的幂");
  106. luat_zbuff_t* zbWc = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  107. luat_zbuff_t* zbWs = (luat_zbuff_t*)luaL_testudata(L, 3, LUAT_ZBUFF_TYPE);
  108. if (!zbWc || !zbWs)
  109. return luaL_error(L, "Wc/Ws 需为 zbuff");
  110. int need = (N / 2) * 2;
  111. if ((int)zbWc->len < need || (int)zbWs->len < need)
  112. return luaL_error(L, "zbuff 太小");
  113. int16_t* Wc = (int16_t*)zbWc->addr;
  114. int16_t* Ws = (int16_t*)zbWs->addr;
  115. for (int k = 0; k < N / 2; k++) {
  116. // angle = 360 * k / N (度),四舍五入
  117. int deg = (int)((int64_t)k * 360 + (N / 2)) / N;
  118. // cos 与 -sin,放大到 Q15(256<<7=32768,钳到 32767)
  119. int s256;
  120. int d = deg;
  121. int neg = 0;
  122. if (d < 0) {
  123. d = -d + 180;
  124. }
  125. d %= 360;
  126. if (d >= 180) {
  127. d -= 180;
  128. neg = 1;
  129. }
  130. if (d <= 90)
  131. s256 = SIN_TABLE256[d];
  132. else
  133. s256 = SIN_TABLE256[180 - d];
  134. int c256 = 0; // cos = sin(deg+90)
  135. int d2 = deg + 90;
  136. neg = 0;
  137. d = d2;
  138. if (d < 0) {
  139. d = -d + 180;
  140. }
  141. d %= 360;
  142. if (d >= 180) {
  143. d -= 180;
  144. neg = 1;
  145. }
  146. if (d <= 90)
  147. c256 = SIN_TABLE256[d];
  148. else
  149. c256 = SIN_TABLE256[180 - d];
  150. if (neg)
  151. c256 = -c256;
  152. int32_t wc = (int32_t)c256 << 7;
  153. if (wc > 32767)
  154. wc = 32767;
  155. if (wc < -32768)
  156. wc = -32768;
  157. // -sin:复用 s256 并加符号
  158. int sgn = 0;
  159. d = deg;
  160. if (d < 0) {
  161. d = -d + 180;
  162. }
  163. d %= 360;
  164. if (d >= 180) {
  165. d -= 180;
  166. sgn = 1;
  167. }
  168. int32_t ws = (int32_t)(sgn ? -s256 : s256) << 7;
  169. if (ws > 32767)
  170. ws = 32767;
  171. if (ws < -32768)
  172. ws = -32768;
  173. Wc[k] = (int16_t)wc;
  174. Ws[k] = (int16_t)ws;
  175. }
  176. return 0;
  177. }
  178. /*
  179. 原地 FFT 计算
  180. @api fft.run(real, imag, N, Wc, Ws[, opts])
  181. @param real 实部容器:
  182. - float32 路径:Lua 数组或 zbuff(float32)
  183. - q15 路径:zbuff(int16)。当 opts.core="q15" 且 opts.input_format 为 "u12"/"u16"/"s16" 时生效
  184. @param imag 虚部容器:同 real。可为 nil(视为全 0)
  185. @int N 点数,2 的幂
  186. @param Wc 旋转因子 cos:
  187. - float32 路径:Lua 数组或 zbuff(float32)
  188. - q15 路径:zbuff(int16),推荐配合 fft.generate_twiddles_q15_to_zbuff 生成
  189. @param Ws 旋转因子 -sin:同 Wc;IFFT 建议传入共轭版本以避免内层符号乘
  190. @table [opts]
  191. - core: "f32" | "q15"(默认 "f32")
  192. * "f32": 浮点内核,精度高(32位),计算稳定,适合精密分析
  193. * "q15": 定点内核,速度快(16位整数),内存省,适合实时处理但精度略低
  194. - input_format: "f32" | "u12" | "u16" | "s16"(q15 时必填其一)
  195. * "f32": 标准浮点输入,适用于已处理的信号数据
  196. * "u12": 12位无符号整数(0~4095),常见于ADC采样,自动去直流分量
  197. * "u16": 16位无符号整数(0~65535),适用于高精度ADC或预处理数据
  198. * "s16": 16位有符号整数(-32768~32767),适用于已去直流的差分信号
  199. @return nil 就地修改 real/imag
  200. @usage
  201. -- f32 路径示例(zbuff float32)
  202. local N=2048
  203. local real=zbuff.create(N*4); local imag=zbuff.create(N*4)
  204. local Wc,Ws=fft.generate_twiddles(N)
  205. fft.run(real, imag, N, Wc, Ws)
  206. -- q15 路径示例(U12 整数输入)
  207. local N=2048
  208. local real_i16=zbuff.create(N*2); local imag_i16=zbuff.create(N*2)
  209. local Wc_q15=zbuff.create((N//2)*2); local Ws_q15=zbuff.create((N//2)*2)
  210. fft.generate_twiddles_q15_to_zbuff(N, Wc_q15, Ws_q15)
  211. -- 写入 U12 数据到 real_i16 后:
  212. fft.run(real_i16, imag_i16, N, Wc_q15, Ws_q15, {core="q15", input_format="u12"})
  213. */
  214. static int l_fft_run(lua_State* L)
  215. {
  216. int N = luaL_checkinteger(L, 3);
  217. if (N <= 1 || (N & (N - 1)))
  218. return luaL_error(L, "N 必须为 2 的幂");
  219. float *r = NULL, *im = NULL, *Wc = NULL, *Ws = NULL;
  220. int r_free = 0, im_free = 0, wc_free = 0, ws_free = 0;
  221. // 可选参数解析(opts):当前版本仅支持 core/input_format(其余项作为未来优化)
  222. const char* core = "f32"; // "f32" | "q15"
  223. const char* input_format = "f32"; // "f32"|"u12"|"u16"|"s16"
  224. // real
  225. luat_zbuff_t* zb = (luat_zbuff_t*)luaL_testudata(L, 1, LUAT_ZBUFF_TYPE);
  226. if (zb) {
  227. r = (float*)zb->addr;
  228. } else {
  229. r = luat_heap_malloc(sizeof(float) * N);
  230. r_free = 1;
  231. if (!r)
  232. return luaL_error(L, "no memory");
  233. if (read_lua_array_float(L, 1, r, N)) {
  234. if (r_free)
  235. luat_heap_free(r);
  236. return luaL_error(L, "real must be number array or zbuff");
  237. }
  238. }
  239. // imag
  240. zb = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  241. if (zb) {
  242. im = (float*)zb->addr;
  243. } else {
  244. im = luat_heap_malloc(sizeof(float) * N);
  245. im_free = 1;
  246. if (!im) {
  247. if (r_free)
  248. luat_heap_free(r);
  249. return luaL_error(L, "no memory");
  250. }
  251. if (read_lua_array_float(L, 2, im, N)) {
  252. if (r_free)
  253. luat_heap_free(r);
  254. if (im_free)
  255. luat_heap_free(im);
  256. return luaL_error(L, "imag must be number array or zbuff");
  257. }
  258. }
  259. // W_real
  260. zb = (luat_zbuff_t*)luaL_testudata(L, 4, LUAT_ZBUFF_TYPE);
  261. if (zb) {
  262. Wc = (float*)zb->addr;
  263. } else {
  264. Wc = luat_heap_malloc(sizeof(float) * (N / 2));
  265. wc_free = 1;
  266. if (!Wc) {
  267. if (r_free)
  268. luat_heap_free(r);
  269. if (im_free)
  270. luat_heap_free(im);
  271. return luaL_error(L, "no memory");
  272. }
  273. if (read_lua_array_float(L, 4, Wc, N / 2)) {
  274. if (r_free)
  275. luat_heap_free(r);
  276. if (im_free)
  277. luat_heap_free(im);
  278. if (wc_free)
  279. luat_heap_free(Wc);
  280. return luaL_error(L, "W_real must be number array or zbuff");
  281. }
  282. }
  283. // W_imag
  284. zb = (luat_zbuff_t*)luaL_testudata(L, 5, LUAT_ZBUFF_TYPE);
  285. if (zb) {
  286. Ws = (float*)zb->addr;
  287. } else {
  288. Ws = luat_heap_malloc(sizeof(float) * (N / 2));
  289. ws_free = 1;
  290. if (!Ws) {
  291. if (r_free)
  292. luat_heap_free(r);
  293. if (im_free)
  294. luat_heap_free(im);
  295. if (wc_free)
  296. luat_heap_free(Wc);
  297. return luaL_error(L, "内存不足");
  298. }
  299. if (read_lua_array_float(L, 5, Ws, N / 2)) {
  300. if (r_free)
  301. luat_heap_free(r);
  302. if (im_free)
  303. luat_heap_free(im);
  304. if (wc_free)
  305. luat_heap_free(Wc);
  306. if (ws_free)
  307. luat_heap_free(Ws);
  308. return luaL_error(L, "W_imag 需为数字数组或 zbuff");
  309. }
  310. }
  311. // 读取第6个参数的 opts(若有)
  312. if (lua_gettop(L) >= 6 && lua_istable(L, 6)) {
  313. lua_getfield(L, 6, "core");
  314. if (!lua_isnil(L, -1))
  315. core = luaL_checkstring(L, -1);
  316. lua_pop(L, 1);
  317. lua_getfield(L, 6, "input_format");
  318. if (!lua_isnil(L, -1))
  319. input_format = luaL_checkstring(L, -1);
  320. lua_pop(L, 1);
  321. }
  322. // 如果选择 q15 内核,且输入为整数 zbuff,则走 q15 路径
  323. int use_q15 = (core && strcmp(core, "q15") == 0);
  324. int integer_input = (strcmp(input_format, "u12") == 0 || strcmp(input_format, "u16") == 0 || strcmp(input_format, "s16") == 0);
  325. if (use_q15 && integer_input) {
  326. // 校验 real/imag 是否为 zbuff(当前整型快速路径仅支持 zbuff)
  327. luat_zbuff_t* zb_real = (luat_zbuff_t*)luaL_testudata(L, 1, LUAT_ZBUFF_TYPE);
  328. luat_zbuff_t* zb_imag = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  329. if (!zb_real) {
  330. if (r_free)
  331. luat_heap_free(r);
  332. if (im_free)
  333. luat_heap_free(im);
  334. if (wc_free)
  335. luat_heap_free(Wc);
  336. if (ws_free)
  337. luat_heap_free(Ws);
  338. return luaL_error(L, "q15 模式要求 real 为整数 zbuff");
  339. }
  340. if ((int)zb_real->len < N * 2) {
  341. if (r_free)
  342. luat_heap_free(r);
  343. if (im_free)
  344. luat_heap_free(im);
  345. if (wc_free)
  346. luat_heap_free(Wc);
  347. if (ws_free)
  348. luat_heap_free(Ws);
  349. return luaL_error(L, "real zbuff 太小");
  350. }
  351. // 原地:将整数输入就地转换为带符号 Q15(覆盖 zbuff)
  352. uint16_t* r16 = (uint16_t*)zb_real->addr;
  353. for (int i = 0; i < N; i++) {
  354. int32_t v;
  355. uint16_t u = r16[i];
  356. if (strcmp(input_format, "u12") == 0) {
  357. v = (int32_t)(u & 0x0FFF) - 2048;
  358. v <<= 4;
  359. } else if (strcmp(input_format, "u16") == 0) {
  360. v = (int32_t)u - 32768;
  361. } else {
  362. v = (int16_t)u;
  363. }
  364. if (v > 32767)
  365. v = 32767;
  366. if (v < -32768)
  367. v = -32768;
  368. ((int16_t*)zb_real->addr)[i] = (int16_t)v;
  369. }
  370. if (zb_imag && (int)zb_imag->len >= N * 2) {
  371. uint16_t* i16 = (uint16_t*)zb_imag->addr;
  372. for (int i = 0; i < N; i++) {
  373. int32_t v;
  374. uint16_t u = i16[i];
  375. if (strcmp(input_format, "u12") == 0) {
  376. v = (int32_t)(u & 0x0FFF) - 2048;
  377. v <<= 4;
  378. } else if (strcmp(input_format, "u16") == 0) {
  379. v = (int32_t)u - 32768;
  380. } else {
  381. v = (int16_t)u;
  382. }
  383. if (v > 32767)
  384. v = 32767;
  385. if (v < -32768)
  386. v = -32768;
  387. ((int16_t*)zb_imag->addr)[i] = (int16_t)v;
  388. }
  389. } else if (zb_imag) {
  390. // 长度不足则清零
  391. memset(zb_imag->addr, 0, zb_imag->len);
  392. }
  393. // 强制要求外部传入 Q15 twiddle
  394. luat_zbuff_t* zbWc = (luat_zbuff_t*)luaL_testudata(L, 4, LUAT_ZBUFF_TYPE);
  395. luat_zbuff_t* zbWs = (luat_zbuff_t*)luaL_testudata(L, 5, LUAT_ZBUFF_TYPE);
  396. const int need_tw = (N / 2) * 2;
  397. if (!(zbWc && zbWs) || (int)zbWc->len < need_tw || (int)zbWs->len < need_tw) {
  398. if (r_free)
  399. luat_heap_free(r);
  400. if (im_free)
  401. luat_heap_free(im);
  402. if (wc_free)
  403. luat_heap_free(Wc);
  404. if (ws_free)
  405. luat_heap_free(Ws);
  406. return luaL_error(L, "q15 需传入 Wc/Ws zbuff,长度 N/2*2 字节");
  407. }
  408. int scale_exp = 0;
  409. int rc = luat_fft_inplace_q15((int16_t*)zb_real->addr, zb_imag ? (int16_t*)zb_imag->addr : NULL,
  410. N, 0, (const int16_t*)zbWc->addr, (const int16_t*)zbWs->addr,
  411. 0, &scale_exp);
  412. if (r_free)
  413. luat_heap_free(r);
  414. if (im_free)
  415. luat_heap_free(im);
  416. if (wc_free)
  417. luat_heap_free(Wc);
  418. if (ws_free)
  419. luat_heap_free(Ws);
  420. if (rc != 0)
  421. return luaL_error(L, "q15 内核执行失败");
  422. return 0;
  423. }
  424. // 默认:沿用 float32 路径
  425. luat_fft_run_inplace(r, im, N, Wc, Ws);
  426. // if input was table, write back
  427. if (!luaL_testudata(L, 1, LUAT_ZBUFF_TYPE)) {
  428. lua_settop(L, 2);
  429. for (int i = 0; i < N; i++) {
  430. lua_pushnumber(L, r[i]);
  431. lua_rawseti(L, 1, i + 1);
  432. }
  433. }
  434. if (!luaL_testudata(L, 2, LUAT_ZBUFF_TYPE)) {
  435. for (int i = 0; i < N; i++) {
  436. lua_pushnumber(L, im[i]);
  437. lua_rawseti(L, 2, i + 1);
  438. }
  439. }
  440. if (r_free)
  441. luat_heap_free(r);
  442. if (im_free)
  443. luat_heap_free(im);
  444. if (wc_free)
  445. luat_heap_free(Wc);
  446. if (ws_free)
  447. luat_heap_free(Ws);
  448. return 0;
  449. }
  450. /*
  451. 原地 IFFT 计算
  452. @api fft.ifft(real, imag, N, Wc, Ws[, opts])
  453. @param real 实部容器,类型与 fft.run 一致
  454. @param imag 虚部容器,类型与 fft.run 一致
  455. @int N 点数,2 的幂
  456. @param Wc 旋转因子 cos:类型同 fft.run
  457. @param Ws 旋转因子 -sin:类型同 fft.run。建议为 IFFT 预共轭传入 +sin 表
  458. @table [opts]
  459. - core: "f32" | "q15"(默认 "f32")
  460. - input_format: "f32" | "u12" | "u16" | "s16"(q15 时必填其一)
  461. @return nil 就地修改 real/imag,并在 f32 路径下包含 1/N 归一化
  462. */
  463. static int l_fft_ifft(lua_State* L)
  464. {
  465. int N = luaL_checkinteger(L, 3);
  466. if (N <= 1 || (N & (N - 1)))
  467. return luaL_error(L, "N 必须为 2 的幂");
  468. float *r = NULL, *im = NULL, *Wc = NULL, *Ws = NULL;
  469. int r_free = 0, im_free = 0, wc_free = 0, ws_free = 0;
  470. // 可选 opts(同 run):当前仅 core/input_format
  471. const char* core = "f32";
  472. const char* input_format = "f32";
  473. luat_zbuff_t* zb = NULL;
  474. zb = (luat_zbuff_t*)luaL_testudata(L, 1, LUAT_ZBUFF_TYPE);
  475. if (zb) {
  476. r = (float*)zb->addr;
  477. } else {
  478. r = luat_heap_malloc(sizeof(float) * N);
  479. r_free = 1;
  480. if (!r)
  481. return luaL_error(L, "no memory");
  482. if (read_lua_array_float(L, 1, r, N)) {
  483. if (r_free)
  484. luat_heap_free(r);
  485. return luaL_error(L, "real must be number array or zbuff");
  486. }
  487. }
  488. zb = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  489. if (zb) {
  490. im = (float*)zb->addr;
  491. } else {
  492. im = luat_heap_malloc(sizeof(float) * N);
  493. im_free = 1;
  494. if (!im) {
  495. if (r_free)
  496. luat_heap_free(r);
  497. return luaL_error(L, "no memory");
  498. }
  499. if (read_lua_array_float(L, 2, im, N)) {
  500. if (r_free)
  501. luat_heap_free(r);
  502. if (im_free)
  503. luat_heap_free(im);
  504. return luaL_error(L, "imag must be number array or zbuff");
  505. }
  506. }
  507. zb = (luat_zbuff_t*)luaL_testudata(L, 4, LUAT_ZBUFF_TYPE);
  508. if (zb) {
  509. Wc = (float*)zb->addr;
  510. } else {
  511. Wc = luat_heap_malloc(sizeof(float) * (N / 2));
  512. wc_free = 1;
  513. if (!Wc) {
  514. if (r_free)
  515. luat_heap_free(r);
  516. if (im_free)
  517. luat_heap_free(im);
  518. return luaL_error(L, "no memory");
  519. }
  520. if (read_lua_array_float(L, 4, Wc, N / 2)) {
  521. if (r_free)
  522. luat_heap_free(r);
  523. if (im_free)
  524. luat_heap_free(im);
  525. if (wc_free)
  526. luat_heap_free(Wc);
  527. return luaL_error(L, "W_real must be number array or zbuff");
  528. }
  529. }
  530. zb = (luat_zbuff_t*)luaL_testudata(L, 5, LUAT_ZBUFF_TYPE);
  531. if (zb) {
  532. Ws = (float*)zb->addr;
  533. } else {
  534. Ws = luat_heap_malloc(sizeof(float) * (N / 2));
  535. ws_free = 1;
  536. if (!Ws) {
  537. if (r_free)
  538. luat_heap_free(r);
  539. if (im_free)
  540. luat_heap_free(im);
  541. if (wc_free)
  542. luat_heap_free(Wc);
  543. return luaL_error(L, "内存不足");
  544. }
  545. if (read_lua_array_float(L, 5, Ws, N / 2)) {
  546. if (r_free)
  547. luat_heap_free(r);
  548. if (im_free)
  549. luat_heap_free(im);
  550. if (wc_free)
  551. luat_heap_free(Wc);
  552. if (ws_free)
  553. luat_heap_free(Ws);
  554. return luaL_error(L, "W_imag 需为数字数组或 zbuff");
  555. }
  556. }
  557. if (lua_gettop(L) >= 6 && lua_istable(L, 6)) {
  558. lua_getfield(L, 6, "core");
  559. if (!lua_isnil(L, -1))
  560. core = luaL_checkstring(L, -1);
  561. lua_pop(L, 1);
  562. lua_getfield(L, 6, "input_format");
  563. if (!lua_isnil(L, -1))
  564. input_format = luaL_checkstring(L, -1);
  565. lua_pop(L, 1);
  566. }
  567. int use_q15 = (core && strcmp(core, "q15") == 0);
  568. int integer_input = (strcmp(input_format, "u12") == 0 || strcmp(input_format, "u16") == 0 || strcmp(input_format, "s16") == 0);
  569. if (use_q15 && integer_input) {
  570. luat_zbuff_t* zb_real = (luat_zbuff_t*)luaL_testudata(L, 1, LUAT_ZBUFF_TYPE);
  571. luat_zbuff_t* zb_imag = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  572. if (!zb_real) {
  573. if (r_free)
  574. luat_heap_free(r);
  575. if (im_free)
  576. luat_heap_free(im);
  577. if (wc_free)
  578. luat_heap_free(Wc);
  579. if (ws_free)
  580. luat_heap_free(Ws);
  581. return luaL_error(L, "q15 模式要求 real 为整数 zbuff");
  582. }
  583. if ((int)zb_real->len < N * 2) {
  584. if (r_free)
  585. luat_heap_free(r);
  586. if (im_free)
  587. luat_heap_free(im);
  588. if (wc_free)
  589. luat_heap_free(Wc);
  590. if (ws_free)
  591. luat_heap_free(Ws);
  592. return luaL_error(L, "real zbuff 太小");
  593. }
  594. int16_t* rq = (int16_t*)luat_heap_malloc(sizeof(int16_t) * N);
  595. int16_t* iq = (int16_t*)luat_heap_malloc(sizeof(int16_t) * N);
  596. if (!rq || !iq) {
  597. if (rq)
  598. luat_heap_free(rq);
  599. if (iq)
  600. luat_heap_free(iq);
  601. if (r_free)
  602. luat_heap_free(r);
  603. if (im_free)
  604. luat_heap_free(im);
  605. if (wc_free)
  606. luat_heap_free(Wc);
  607. if (ws_free)
  608. luat_heap_free(Ws);
  609. return luaL_error(L, "内存不足");
  610. }
  611. const uint16_t* r16 = (const uint16_t*)zb_real->addr;
  612. for (int i = 0; i < N; i++) {
  613. int32_t v;
  614. uint16_t u = r16[i];
  615. if (strcmp(input_format, "u12") == 0) {
  616. v = ((int32_t)(u & 0x0FFF)) - 2048;
  617. v <<= 4;
  618. } else if (strcmp(input_format, "u16") == 0) {
  619. v = (int32_t)u - 32768;
  620. } else {
  621. v = (int16_t)u;
  622. }
  623. if (v > 32767)
  624. v = 32767;
  625. if (v < -32768)
  626. v = -32768;
  627. rq[i] = (int16_t)v;
  628. }
  629. if (zb_imag && (int)zb_imag->len >= N * 2) {
  630. const uint16_t* i16 = (const uint16_t*)zb_imag->addr;
  631. for (int i = 0; i < N; i++) {
  632. int32_t v;
  633. uint16_t u = i16[i];
  634. if (strcmp(input_format, "u12") == 0) {
  635. v = ((int32_t)(u & 0x0FFF)) - 2048;
  636. v <<= 4;
  637. } else if (strcmp(input_format, "u16") == 0) {
  638. v = (int32_t)u - 32768;
  639. } else {
  640. v = (int16_t)u;
  641. }
  642. if (v > 32767)
  643. v = 32767;
  644. if (v < -32768)
  645. v = -32768;
  646. iq[i] = (int16_t)v;
  647. }
  648. } else {
  649. memset(iq, 0, sizeof(int16_t) * N);
  650. }
  651. // 使用传入的 Q15 旋转因子(zbuff)或按需生成
  652. luat_zbuff_t* zbWc = (luat_zbuff_t*)luaL_testudata(L, 4, LUAT_ZBUFF_TYPE);
  653. luat_zbuff_t* zbWs = (luat_zbuff_t*)luaL_testudata(L, 5, LUAT_ZBUFF_TYPE);
  654. const int need_tw = (N / 2) * 2;
  655. const int16_t* Wcq = NULL;
  656. const int16_t* Wsq = NULL;
  657. int16_t* Wcq_alloc = NULL;
  658. int16_t* Wsq_alloc = NULL;
  659. if (zbWc && zbWs && (int)zbWc->len >= need_tw && (int)zbWs->len >= need_tw) {
  660. Wcq = (const int16_t*)zbWc->addr;
  661. Wsq = (const int16_t*)zbWs->addr;
  662. } else {
  663. Wcq_alloc = (int16_t*)luat_heap_malloc(sizeof(int16_t) * (N / 2));
  664. Wsq_alloc = (int16_t*)luat_heap_malloc(sizeof(int16_t) * (N / 2));
  665. if (!Wcq_alloc || !Wsq_alloc) {
  666. if (Wcq_alloc)
  667. luat_heap_free(Wcq_alloc);
  668. if (Wsq_alloc)
  669. luat_heap_free(Wsq_alloc);
  670. luat_heap_free(rq);
  671. luat_heap_free(iq);
  672. if (r_free)
  673. luat_heap_free(r);
  674. if (im_free)
  675. luat_heap_free(im);
  676. if (wc_free)
  677. luat_heap_free(Wc);
  678. if (ws_free)
  679. luat_heap_free(Ws);
  680. return luaL_error(L, "内存不足");
  681. }
  682. luat_fft_generate_twiddles_q15(Wcq_alloc, Wsq_alloc, N);
  683. Wcq = Wcq_alloc;
  684. Wsq = Wsq_alloc;
  685. }
  686. int scale_exp = 0;
  687. int rc = luat_fft_inplace_q15(rq, iq, N, 1, Wcq, Wsq, 0, &scale_exp); // inverse=1
  688. if (Wcq_alloc)
  689. luat_heap_free(Wcq_alloc);
  690. if (Wsq_alloc)
  691. luat_heap_free(Wsq_alloc);
  692. if (rc != 0) {
  693. luat_heap_free(rq);
  694. luat_heap_free(iq);
  695. if (r_free)
  696. luat_heap_free(r);
  697. if (im_free)
  698. luat_heap_free(im);
  699. if (wc_free)
  700. luat_heap_free(Wc);
  701. if (ws_free)
  702. luat_heap_free(Ws);
  703. return luaL_error(L, "q15 内核执行失败");
  704. }
  705. for (int i = 0; i < N; i++)
  706. ((int16_t*)zb_real->addr)[i] = rq[i];
  707. if (zb_imag && (int)zb_imag->len >= N * 2) {
  708. for (int i = 0; i < N; i++)
  709. ((int16_t*)zb_imag->addr)[i] = iq[i];
  710. }
  711. luat_heap_free(rq);
  712. luat_heap_free(iq);
  713. if (r_free)
  714. luat_heap_free(r);
  715. if (im_free)
  716. luat_heap_free(im);
  717. if (wc_free)
  718. luat_heap_free(Wc);
  719. if (ws_free)
  720. luat_heap_free(Ws);
  721. return 0;
  722. }
  723. luat_ifft_run_inplace(r, im, N, Wc, Ws);
  724. if (!luaL_testudata(L, 1, LUAT_ZBUFF_TYPE)) {
  725. lua_settop(L, 2);
  726. for (int i = 0; i < N; i++) {
  727. lua_pushnumber(L, r[i]);
  728. lua_rawseti(L, 1, i + 1);
  729. }
  730. }
  731. if (!luaL_testudata(L, 2, LUAT_ZBUFF_TYPE)) {
  732. for (int i = 0; i < N; i++) {
  733. lua_pushnumber(L, im[i]);
  734. lua_rawseti(L, 2, i + 1);
  735. }
  736. }
  737. if (r_free)
  738. luat_heap_free(r);
  739. if (im_free)
  740. luat_heap_free(im);
  741. if (wc_free)
  742. luat_heap_free(Wc);
  743. if (ws_free)
  744. luat_heap_free(Ws);
  745. return 0;
  746. }
  747. /*
  748. 频域积分(1/(jω))
  749. @api fft.fft_integral(real, imag, n, df)
  750. @param real 实部(float32,Lua 数组或 zbuff)
  751. @param imag 虚部(float32,Lua 数组或 zbuff)
  752. @int n 点数,2 的幂
  753. @number df 频率分辨率(fs/n)
  754. @return nil 原地修改 real/imag(DC 置 0)
  755. @usage
  756. -- 先完成 FFT 得到谱 (real, imag),再调用积分:
  757. fft.fft_integral(real, imag, N, fs/N)
  758. */
  759. static int l_fft_integral(lua_State* L)
  760. {
  761. int n = luaL_checkinteger(L, 3);
  762. float df = (float)luaL_checknumber(L, 4);
  763. if (n <= 1 || (n & (n - 1)))
  764. return luaL_error(L, "n must be power of 2");
  765. float *r = NULL, *im = NULL;
  766. int r_free = 0, im_free = 0;
  767. luat_zbuff_t* zb = NULL;
  768. zb = (luat_zbuff_t*)luaL_testudata(L, 1, LUAT_ZBUFF_TYPE);
  769. if (zb) {
  770. r = (float*)zb->addr;
  771. } else {
  772. r = luat_heap_malloc(sizeof(float) * n);
  773. r_free = 1;
  774. if (!r)
  775. return luaL_error(L, "no memory");
  776. if (read_lua_array_float(L, 1, r, n)) {
  777. if (r_free)
  778. luat_heap_free(r);
  779. return luaL_error(L, "real must be number array or zbuff");
  780. }
  781. }
  782. zb = (luat_zbuff_t*)luaL_testudata(L, 2, LUAT_ZBUFF_TYPE);
  783. if (zb) {
  784. im = (float*)zb->addr;
  785. } else {
  786. im = luat_heap_malloc(sizeof(float) * n);
  787. im_free = 1;
  788. if (!im) {
  789. if (r_free)
  790. luat_heap_free(r);
  791. return luaL_error(L, "no memory");
  792. }
  793. if (read_lua_array_float(L, 2, im, n)) {
  794. if (r_free)
  795. luat_heap_free(r);
  796. if (im_free)
  797. luat_heap_free(im);
  798. return luaL_error(L, "imag must be number array or zbuff");
  799. }
  800. }
  801. luat_fft_integral_inplace(r, im, n, df);
  802. if (!luaL_testudata(L, 1, LUAT_ZBUFF_TYPE)) {
  803. lua_settop(L, 2);
  804. for (int i = 0; i < n; i++) {
  805. lua_pushnumber(L, r[i]);
  806. lua_rawseti(L, 1, i + 1);
  807. }
  808. }
  809. if (!luaL_testudata(L, 2, LUAT_ZBUFF_TYPE)) {
  810. for (int i = 0; i < n; i++) {
  811. lua_pushnumber(L, im[i]);
  812. lua_rawseti(L, 2, i + 1);
  813. }
  814. }
  815. if (r_free)
  816. luat_heap_free(r);
  817. if (im_free)
  818. luat_heap_free(im);
  819. return 0;
  820. }
  821. static const rotable_Reg_t reg_fft[] = {
  822. { "generate_twiddles", ROREG_FUNC(l_fft_generate_twiddles) },
  823. { "generate_twiddles_q15_to_zbuff", ROREG_FUNC(l_fft_generate_twiddles_q15_to_zbuff) },
  824. { "run", ROREG_FUNC(l_fft_run) },
  825. { "ifft", ROREG_FUNC(l_fft_ifft) },
  826. { "fft_integral", ROREG_FUNC(l_fft_integral) },
  827. { NULL, ROREG_INT(0) }
  828. };
  829. LUAMOD_API int luaopen_fft(lua_State* L)
  830. {
  831. luat_newlib2(L, reg_fft);
  832. return 1;
  833. }