fastfir.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #!/usr/bin/env python
  2. from Numeric import *
  3. from FFT import *
  4. def make_random(len):
  5. import random
  6. res=[]
  7. for i in range(int(len)):
  8. r=random.uniform(-1,1)
  9. i=random.uniform(-1,1)
  10. res.append( complex(r,i) )
  11. return res
  12. def slowfilter(sig,h):
  13. translen = len(h)-1
  14. return convolve(sig,h)[translen:-translen]
  15. def nextpow2(x):
  16. return 2 ** math.ceil(math.log(x)/math.log(2))
  17. def fastfilter(sig,h,nfft=None):
  18. if nfft is None:
  19. nfft = int( nextpow2( 2*len(h) ) )
  20. H = fft( h , nfft )
  21. scraplen = len(h)-1
  22. keeplen = nfft-scraplen
  23. res=[]
  24. isdone = 0
  25. lastidx = nfft
  26. idx0 = 0
  27. while not isdone:
  28. idx1 = idx0 + nfft
  29. if idx1 >= len(sig):
  30. idx1 = len(sig)
  31. lastidx = idx1-idx0
  32. if lastidx <= scraplen:
  33. break
  34. isdone = 1
  35. Fss = fft(sig[idx0:idx1],nfft)
  36. fm = Fss * H
  37. m = inverse_fft(fm)
  38. res.append( m[scraplen:lastidx] )
  39. idx0 += keeplen
  40. return concatenate( res )
  41. def main():
  42. import sys
  43. from getopt import getopt
  44. opts,args = getopt(sys.argv[1:],'rn:l:')
  45. opts=dict(opts)
  46. siglen = int(opts.get('-l',1e4 ) )
  47. hlen =50
  48. nfft = int(opts.get('-n',128) )
  49. usereal = opts.has_key('-r')
  50. print 'nfft=%d'%nfft
  51. # make a signal
  52. sig = make_random( siglen )
  53. # make an impulse response
  54. h = make_random( hlen )
  55. #h=[1]*2+[0]*3
  56. if usereal:
  57. sig=[c.real for c in sig]
  58. h=[c.real for c in h]
  59. # perform MAC filtering
  60. yslow = slowfilter(sig,h)
  61. #print '<YSLOW>',yslow,'</YSLOW>'
  62. #yfast = fastfilter(sig,h,nfft)
  63. yfast = utilfastfilter(sig,h,nfft,usereal)
  64. #print yfast
  65. print 'len(yslow)=%d'%len(yslow)
  66. print 'len(yfast)=%d'%len(yfast)
  67. diff = yslow-yfast
  68. snr = 10*log10( abs( vdot(yslow,yslow) / vdot(diff,diff) ) )
  69. print 'snr=%s' % snr
  70. if snr < 10.0:
  71. print 'h=',h
  72. print 'sig=',sig[:5],'...'
  73. print 'yslow=',yslow[:5],'...'
  74. print 'yfast=',yfast[:5],'...'
  75. def utilfastfilter(sig,h,nfft,usereal):
  76. import compfft
  77. import os
  78. open( 'sig.dat','w').write( compfft.dopack(sig,'f',not usereal) )
  79. open( 'h.dat','w').write( compfft.dopack(h,'f',not usereal) )
  80. if usereal:
  81. util = './fastconvr'
  82. else:
  83. util = './fastconv'
  84. cmd = 'time %s -n %d -i sig.dat -h h.dat -o out.dat' % (util, nfft)
  85. print cmd
  86. ec = os.system(cmd)
  87. print 'exited->',ec
  88. return compfft.dounpack(open('out.dat').read(),'f',not usereal)
  89. if __name__ == "__main__":
  90. main()