mk_test.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #!/usr/bin/env python
  2. import FFT
  3. import sys
  4. import random
  5. import re
  6. j=complex(0,1)
  7. def randvec(n,iscomplex):
  8. if iscomplex:
  9. return [
  10. int(random.uniform(-32768,32767) ) + j*int(random.uniform(-32768,32767) )
  11. for i in range(n) ]
  12. else:
  13. return [ int(random.uniform(-32768,32767) ) for i in range(n) ]
  14. def c_format(v,round=0):
  15. if round:
  16. return ','.join( [ '{%d,%d}' %(int(c.real),int(c.imag) ) for c in v ] )
  17. else:
  18. s= ','.join( [ '{%.60f ,%.60f }' %(c.real,c.imag) for c in v ] )
  19. return re.sub(r'\.?0+ ',' ',s)
  20. def test_cpx( n,inverse ,short):
  21. v = randvec(n,1)
  22. scale = 1
  23. if short:
  24. minsnr=30
  25. else:
  26. minsnr=100
  27. if inverse:
  28. tvecout = FFT.inverse_fft(v)
  29. if short:
  30. scale = 1
  31. else:
  32. scale = len(v)
  33. else:
  34. tvecout = FFT.fft(v)
  35. if short:
  36. scale = 1.0/len(v)
  37. tvecout = [ c * scale for c in tvecout ]
  38. s="""#define NFFT %d""" % len(v) + """
  39. {
  40. double snr;
  41. kiss_fft_cpx test_vec_in[NFFT] = { """ + c_format(v) + """};
  42. kiss_fft_cpx test_vec_out[NFFT] = {""" + c_format( tvecout ) + """};
  43. kiss_fft_cpx testbuf[NFFT];
  44. void * cfg = kiss_fft_alloc(NFFT,%d,0,0);""" % inverse + """
  45. kiss_fft(cfg,test_vec_in,testbuf);
  46. snr = snr_compare(test_vec_out,testbuf,NFFT);
  47. printf("DATATYPE=" xstr(kiss_fft_scalar) ", FFT n=%d, inverse=%d, snr = %g dB\\n",NFFT,""" + str(inverse) + """,snr);
  48. if (snr<""" + str(minsnr) + """)
  49. exit_code++;
  50. free(cfg);
  51. }
  52. #undef NFFT
  53. """
  54. return s
  55. def compare_func():
  56. s="""
  57. #define xstr(s) str(s)
  58. #define str(s) #s
  59. double snr_compare( kiss_fft_cpx * test_vec_out,kiss_fft_cpx * testbuf, int n)
  60. {
  61. int k;
  62. double sigpow,noisepow,err,snr,scale=0;
  63. kiss_fft_cpx err;
  64. sigpow = noisepow = .000000000000000000000000000001;
  65. for (k=0;k<n;++k) {
  66. sigpow += test_vec_out[k].r * test_vec_out[k].r +
  67. test_vec_out[k].i * test_vec_out[k].i;
  68. C_SUB(err,test_vec_out[k],testbuf[k].r);
  69. noisepow += err.r * err.r + err.i + err.i;
  70. if (test_vec_out[k].r)
  71. scale += testbuf[k].r / test_vec_out[k].r;
  72. }
  73. snr = 10*log10( sigpow / noisepow );
  74. scale /= n;
  75. if (snr<10)
  76. printf( "\\npoor snr, try a scaling factor %f\\n" , scale );
  77. return snr;
  78. }
  79. """
  80. return s
  81. def main():
  82. from getopt import getopt
  83. opts,args = getopt(sys.argv[1:],'s')
  84. opts = dict(opts)
  85. short = int( opts.has_key('-s') )
  86. fftsizes = args
  87. if not fftsizes:
  88. fftsizes = [ 1800 ]
  89. print '#include "kiss_fft.h"'
  90. print compare_func()
  91. print "int main() { int exit_code=0;\n"
  92. for n in fftsizes:
  93. n = int(n)
  94. print test_cpx(n,0,short)
  95. print test_cpx(n,1,short)
  96. print """
  97. return exit_code;
  98. }
  99. """
  100. if __name__ == "__main__":
  101. main()