Fix PYFFTW usage. Use numexpr

This commit is contained in:
Guilhem Lavaux 2014-07-04 16:53:17 +02:00
parent 815c0b616a
commit de1bda366a
3 changed files with 15 additions and 12 deletions

View file

@ -1,3 +1,4 @@
import numexpr as ne
import multiprocessing
import pyfftw
import weakref
@ -13,14 +14,14 @@ class CubeFT(object):
self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64')
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32')
self.irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False)
self.rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False)
self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False)
self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False)
def rfft(self):
return self.rfft()*(self.L/self.N)**3
return ne.evaluate('c*a', local_dict={'c':self._rfft(normalise_idft=False),'a':(self.L/self.N)**3})
def irfft(self):
return self.irfft()/self.L**3
return ne.evaluate('c*a', local_dict={'c':self._irfft(normalise_idft=False),'a':(1/self.L)**3})
def get_dhat(self):
return self._dhat
@ -152,16 +153,18 @@ class LagrangianPerturbation(object):
k2 = self._get_k2()
k2[0,0,0] = 1
potgen0 = lambda i: ne.evaluate('kdir**2*d/k2',local_dict={'kdir':self._kdir(i),'d':self.dhat,'k2':k2} )
potgen = lambda i,j: ne.evaluate('kdir0*kdir1*d/k2',local_dict={'kdir0':self._kdir(i),'kdir1':self._kdir(j),'d':self.dhat,'k2':k2} )
if 'lpt2_potential' not in self.cache:
print("Rebuilding potential...")
div_phi2 = np.zeros((self.N,self.N,self.N), dtype=np.float64)
for j in xrange(3):
q = self._do_irfft( self._kdir(j)**2*self.dhat / k2 ).copy()
q = self._do_irfft( potgen0(j) ).copy()
for i in xrange(j+1, 3):
div_phi2 += q * self._do_irfft( self._kdir(i)**2*self.dhat / k2 )
div_phi2 -= (self._do_irfft(self._kdir(j)*self._kdir(i)*self.dhat / k2 ) )**2
div_phi2 += q * self._do_irfft( potgen0(i) )
div_phi2 -= self._do_irfft(potgen(i,j))**2
div_phi2 *= 1/self.L**6
phi2_hat = -self._do_rfft(div_phi2) / k2
#self.cache['lpt2_potential'] = phi2_hat
del div_phi2