From dcf77b40d6330ba86c76f4ecbabe26b5cdaefd71 Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Tue, 15 Jul 2014 15:49:11 +0200 Subject: [PATCH] New fftw module --- python/_project.pyx | 38 ++++++++++++++++-------------------- python/cosmotool/__init__.py | 15 +++++++++----- python/cosmotool/fftw.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 26 deletions(-) create mode 100644 python/cosmotool/fftw.py diff --git a/python/_project.pyx b/python/_project.pyx index 2de1834..317f001 100644 --- a/python/_project.pyx +++ b/python/_project.pyx @@ -45,13 +45,6 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, ry -= iy rz -= iz - while ix < 0: - ix += Ngrid - while iy < 0: - iy += Ngrid - while iz < 0: - iz += Ngrid - jx = (ix+1)%Ngrid jy = (iy+1)%Ngrid jz = (iz+1)%Ngrid @@ -194,12 +187,12 @@ def interp3d(x not None, y not None, Nelt = ax.size with nogil: if myperiodic: - for i in xrange(Nelt): + for i in prange(Nelt): if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: with gil: raise ierror else: - for i in xrange(Nelt): + for i in prange(Nelt): if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: with gil: raise ierror @@ -213,6 +206,7 @@ def interp3d(x not None, y not None, if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: raise ierror return retval + @cython.boundscheck(False) @cython.cdivision(True) cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, @@ -376,12 +370,17 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g, cdef double a[3], c[3] cdef int b[3], b1[3] cdef int do_not_put + cdef DTYPE_t[:,:] ax + cdef DTYPE_t[:,:,:] ag + + ax = x + ag = g for i in range(x.shape[0]): do_not_put = 0 for j in range(3): - a[j] = (x[i,j]+half_Box)*delta_Box + a[j] = (ax[i,j]+half_Box)*delta_Box b[j] = int(floor(a[j])) b1[j] = (b[j]+1) % Ngrid @@ -390,18 +389,15 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g, b[j] %= Ngrid - assert b[j] >= 0 and b[j] < Ngrid - assert b1[j] >= 0 and b1[j] < Ngrid - - g[b[0],b[1],b[2]] += c[0]*c[1]*c[2] - g[b1[0],b[1],b[2]] += a[0]*c[1]*c[2] - g[b[0],b1[1],b[2]] += c[0]*a[1]*c[2] - g[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2] + ag[b[0],b[1],b[2]] += c[0]*c[1]*c[2] + ag[b1[0],b[1],b[2]] += a[0]*c[1]*c[2] + ag[b[0],b1[1],b[2]] += c[0]*a[1]*c[2] + ag[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2] - g[b[0],b[1],b1[2]] += c[0]*c[1]*a[2] - g[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2] - g[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2] - g[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2] + ag[b[0],b[1],b1[2]] += c[0]*c[1]*a[2] + ag[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2] + ag[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2] + ag[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2] @cython.boundscheck(False) diff --git a/python/cosmotool/__init__.py b/python/cosmotool/__init__.py index 6c1d376..d53ceff 100644 --- a/python/cosmotool/__init__.py +++ b/python/cosmotool/__init__.py @@ -1,7 +1,12 @@ from _cosmotool import * from _project import * -from grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase -from borg import read_borg_vol -from cic import cicParticles -from simu import loadRamsesAll, simpleWriteGadget, SimulationBare -from timing import timeit, timeit_quiet +from .grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase +from .borg import read_borg_vol +from .cic import cicParticles +from .simu import loadRamsesAll, simpleWriteGadget, SimulationBare +from .timing import timeit, timeit_quiet + +try: + from .fftw import CubeFT +except ImportError: + print("No FFTW support") diff --git a/python/cosmotool/fftw.py b/python/cosmotool/fftw.py new file mode 100644 index 0000000..a640789 --- /dev/null +++ b/python/cosmotool/fftw.py @@ -0,0 +1,36 @@ +import pyfftw +import multiprocessing +import numpy as np +import numexpr as ne + +class CubeFT(object): + def __init__(self, L, N, max_cpu=-1): + + self.N = N + self.align = pyfftw.simd_alignment + self.L = L + self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu + self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64') + self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32') + self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False) + self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False) + + def rfft(self): + return ne.evaluate('c*a', out=self._dhat, local_dict={'c':self._rfft(normalise_idft=False),'a':(self.L/self.N)**3}, casting='unsafe') + + def irfft(self): + return ne.evaluate('c*a', out=self._density, local_dict={'c':self._irfft(normalise_idft=False),'a':(1/self.L)**3}, casting='unsafe') + + def get_dhat(self): + return self._dhat + def set_dhat(self, in_dhat): + self._dhat[:] = in_dhat + dhat = property(get_dhat, set_dhat, None) + + def get_density(self): + return self._density + def set_density(self, d): + self._density[:] = d + density = property(get_density, set_density, None) + +