Use more numexpr and pyfftw

This commit is contained in:
Guilhem Lavaux 2014-09-19 11:32:14 +02:00
parent 025cdba6a9
commit 7c346fa1b2
2 changed files with 32 additions and 21 deletions

View file

@ -93,6 +93,9 @@ class LagrangianPerturbation(object):
if supersample > 1:
self.upgrade_sampling(supersample)
self.ik = np.fft.fftfreq(self.N, d=L/self.N)*2*np.pi
self._kx = self.ik[:,None,None]
self._ky = self.ik[None,:,None]
self._kz = self.ik[None,None,:(self.N/2+1)]
self.cache = {}#weakref.WeakValueDictionary()
def upgrade_sampling(self, supersample):
@ -111,14 +114,15 @@ class LagrangianPerturbation(object):
self.cube = CubeFT(self.L, self.N, max_cpu=self.max_cpu)
def _gradient(self, phi, direction):
self.cube.dhat = self._kdir(direction)*1j*phi
ne.evaluate('phi_hat * i * kv / (kx**2 + ky**2 + kz**2)', out=self.cube.dhat,
local_dict={'i':-1j, 'phi_hat':phi, 'kv':self._kdir(direction),
'kx':self._kx, 'ky':self._ky, 'kz':self._kz}
# self.cube.dhat = self._kdir(direction)*1j*phi
self.cube.dhat[0,0,0] = 0
return self.cube.irfft()
def lpt1(self, direction=0):
k2 = self._get_k2()
k2[0,0,0] = 1
return self._gradient(self.dhat/k2, direction)
return self._gradient(self.dhat, direction)
def new_shape(self,direction, q=3, half=False):
N0 = (self.N/2+1) if half else self.N
@ -141,20 +145,22 @@ class LagrangianPerturbation(object):
self.cache['k2'] = k2
return k2
def _do_irfft(self, array):
self.cube.dhat = array
def _do_irfft(self, array, copy=True):
if copy:
self.cube.dhat = array
return self.cube.irfft()
def _do_rfft(self, array):
self.cube.density = array
def _do_rfft(self, array, copy=True):
if copy:
self.cube.density = array
return self.cube.rfft()
def lpt2(self, direction=0):
k2 = self._get_k2()
k2[0,0,0] = 1
# k2 = self._get_k2()
# k2[0,0,0] = 1
potgen0 = lambda i: ne.evaluate('kdir**2*d/k2',local_dict={'kdir':self._kdir(i),'d':self.dhat,'k2':k2} )
potgen = lambda i,j: ne.evaluate('kdir0*kdir1*d/k2',local_dict={'kdir0':self._kdir(i),'kdir1':self._kdir(j),'d':self.dhat,'k2':k2} )
potgen0 = lambda i: ne.evaluate('kdir**2*d/k2',out=self.cube.dhat,local_dict={'kdir':self._kdir(i),'d':self.dhat,'k2':k2} )
potgen = lambda i,j: ne.evaluate('kdir0*kdir1*d/k2',out=self.cube.dhat,local_dict={'kdir0':self._kdir(i),'kdir1':self._kdir(j),'d':self.dhat,'k2':k2} )
if 'lpt2_potential' not in self.cache:
print("Rebuilding potential...")
@ -162,10 +168,14 @@ class LagrangianPerturbation(object):
for j in xrange(3):
q = self._do_irfft( potgen0(j) ).copy()
for i in xrange(j+1, 3):
div_phi2 += q * self._do_irfft( potgen0(i) )
div_phi2 -= self._do_irfft(potgen(i,j))**2
ne.evaluate('div + q * pot', out=div_phi2,
local_dict={'q':q,'pot':self._do_irfft( potgen0(i), copy=False ) }
)
ne.evaluate('div - pot**2',out=div_phi2,
local_dict={'div':div_phi2,'pot':self._do_irfft(potgen(i,j), copy=False) }
)
phi2_hat = -self._do_rfft(div_phi2) / k2
phi2_hat = self._do_rfft(div_phi2)
#self.cache['lpt2_potential'] = phi2_hat
del div_phi2
else: