New fftw module

This commit is contained in:
Guilhem Lavaux 2014-07-15 15:49:11 +02:00
parent 6150597c67
commit dcf77b40d6
3 changed files with 63 additions and 26 deletions

View File

@ -45,13 +45,6 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
ry -= iy ry -= iy
rz -= iz rz -= iz
while ix < 0:
ix += Ngrid
while iy < 0:
iy += Ngrid
while iz < 0:
iz += Ngrid
jx = (ix+1)%Ngrid jx = (ix+1)%Ngrid
jy = (iy+1)%Ngrid jy = (iy+1)%Ngrid
jz = (iz+1)%Ngrid jz = (iz+1)%Ngrid
@ -194,12 +187,12 @@ def interp3d(x not None, y not None,
Nelt = ax.size Nelt = ax.size
with nogil: with nogil:
if myperiodic: if myperiodic:
for i in xrange(Nelt): for i in prange(Nelt):
if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil: with gil:
raise ierror raise ierror
else: else:
for i in xrange(Nelt): for i in prange(Nelt):
if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil: with gil:
raise ierror raise ierror
@ -213,6 +206,7 @@ def interp3d(x not None, y not None,
if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror raise ierror
return retval return retval
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.cdivision(True) @cython.cdivision(True)
cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
@ -376,12 +370,17 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
cdef double a[3], c[3] cdef double a[3], c[3]
cdef int b[3], b1[3] cdef int b[3], b1[3]
cdef int do_not_put cdef int do_not_put
cdef DTYPE_t[:,:] ax
cdef DTYPE_t[:,:,:] ag
ax = x
ag = g
for i in range(x.shape[0]): for i in range(x.shape[0]):
do_not_put = 0 do_not_put = 0
for j in range(3): for j in range(3):
a[j] = (x[i,j]+half_Box)*delta_Box a[j] = (ax[i,j]+half_Box)*delta_Box
b[j] = int(floor(a[j])) b[j] = int(floor(a[j]))
b1[j] = (b[j]+1) % Ngrid b1[j] = (b[j]+1) % Ngrid
@ -390,18 +389,15 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
b[j] %= Ngrid b[j] %= Ngrid
assert b[j] >= 0 and b[j] < Ngrid ag[b[0],b[1],b[2]] += c[0]*c[1]*c[2]
assert b1[j] >= 0 and b1[j] < Ngrid ag[b1[0],b[1],b[2]] += a[0]*c[1]*c[2]
ag[b[0],b1[1],b[2]] += c[0]*a[1]*c[2]
ag[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2]
g[b[0],b[1],b[2]] += c[0]*c[1]*c[2] ag[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
g[b1[0],b[1],b[2]] += a[0]*c[1]*c[2] ag[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
g[b[0],b1[1],b[2]] += c[0]*a[1]*c[2] ag[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
g[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2] ag[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
g[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
g[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
g[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
g[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
@cython.boundscheck(False) @cython.boundscheck(False)

View File

@ -1,7 +1,12 @@
from _cosmotool import * from _cosmotool import *
from _project import * from _project import *
from grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase from .grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase
from borg import read_borg_vol from .borg import read_borg_vol
from cic import cicParticles from .cic import cicParticles
from simu import loadRamsesAll, simpleWriteGadget, SimulationBare from .simu import loadRamsesAll, simpleWriteGadget, SimulationBare
from timing import timeit, timeit_quiet from .timing import timeit, timeit_quiet
try:
from .fftw import CubeFT
except ImportError:
print("No FFTW support")

36
python/cosmotool/fftw.py Normal file
View File

@ -0,0 +1,36 @@
import pyfftw
import multiprocessing
import numpy as np
import numexpr as ne
class CubeFT(object):
def __init__(self, L, N, max_cpu=-1):
self.N = N
self.align = pyfftw.simd_alignment
self.L = L
self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64')
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32')
self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False)
self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False)
def rfft(self):
return ne.evaluate('c*a', out=self._dhat, local_dict={'c':self._rfft(normalise_idft=False),'a':(self.L/self.N)**3}, casting='unsafe')
def irfft(self):
return ne.evaluate('c*a', out=self._density, local_dict={'c':self._irfft(normalise_idft=False),'a':(1/self.L)**3}, casting='unsafe')
def get_dhat(self):
return self._dhat
def set_dhat(self, in_dhat):
self._dhat[:] = in_dhat
dhat = property(get_dhat, set_dhat, None)
def get_density(self):
return self._density
def set_density(self, d):
self._density[:] = d
density = property(get_density, set_density, None)