This commit is contained in:
Guilhem Lavaux 2014-07-28 09:34:20 +02:00
commit 025cdba6a9
5 changed files with 185 additions and 60 deletions

View File

@ -1,7 +1,7 @@
from cpython cimport bool
from cython cimport view
from cython.parallel import prange, parallel
from libc.math cimport sin, cos, abs, floor, sqrt
from libc.math cimport sin, cos, abs, floor, round, sqrt
import numpy as np
cimport numpy as npx
cimport cython
@ -45,13 +45,6 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
ry -= iy
rz -= iz
while ix < 0:
ix += Ngrid
while iy < 0:
iy += Ngrid
while iz < 0:
iz += Ngrid
jx = (ix+1)%Ngrid
jy = (iy+1)%Ngrid
jz = (iz+1)%Ngrid
@ -89,6 +82,69 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
return 0
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int ngp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
DTYPE_t z,
DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil:
cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox
cdef int ix, iy, iz
cdef DTYPE_t f[2][2][2]
cdef DTYPE_t rx, ry, rz
cdef int jx, jy, jz
rx = (inv_delta*x)
ry = (inv_delta*y)
rz = (inv_delta*z)
ix = int(round(rx))
iy = int(round(ry))
iz = int(round(rz))
ix = ix%Ngrid
iy = iy%Ngrid
iz = iz%Ngrid
retval[0] = d[ix ,iy ,iz ]
return 0
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int ngp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
DTYPE_t z,
DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil:
cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox
cdef int ix, iy, iz
cdef DTYPE_t f[2][2][2]
cdef DTYPE_t rx, ry, rz
cdef int jx, jy, jz
rx = (inv_delta*x)
ry = (inv_delta*y)
rz = (inv_delta*z)
ix = int(round(rx))
iy = int(round(ry))
iz = int(round(rz))
if (ix < 0 or ix >= Ngrid):
return -1
if (iy < 0 or iy >= Ngrid):
return -2
if (iz < 0 or iz >= Ngrid):
return -3
retval[0] = d[ix ,iy ,iz ]
return 0
@cython.boundscheck(False)
@cython.cdivision(True)
@ -152,8 +208,8 @@ cdef int interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
def interp3d(x not None, y not None,
z not None,
npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox,
bool periodic=False, bool centered=True):
""" interp3d(x,y,z,d,Lbox,periodic=False) -> interpolated values
bool periodic=False, bool centered=True, bool ngp=False):
""" interp3d(x,y,z,d,Lbox,periodic=False,centered=True,ngp=False) -> interpolated values
Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic
"""
@ -164,10 +220,11 @@ def interp3d(x not None, y not None,
cdef DTYPE_t retval
cdef long i
cdef long Nelt
cdef int myperiodic
cdef int myperiodic, myngp
cdef DTYPE_t shifter
myperiodic = periodic
myngp = ngp
if centered:
shifter = Lbox/2
@ -193,26 +250,46 @@ def interp3d(x not None, y not None,
in_slice = d
Nelt = ax.size
with nogil:
if not myngp:
if myperiodic:
for i in xrange(Nelt):
for i in prange(Nelt):
if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
else:
for i in xrange(Nelt):
for i in prange(Nelt):
if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
else:
if myperiodic:
for i in prange(Nelt):
if ngp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
else:
for i in prange(Nelt):
if ngp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
return out
else:
if not myngp:
if periodic:
if interp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
else:
if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
else:
if periodic:
if ngp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
else:
if ngp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
return retval
@cython.boundscheck(False)
@cython.cdivision(True)
cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
@ -376,12 +453,17 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
cdef double a[3], c[3]
cdef int b[3], b1[3]
cdef int do_not_put
cdef DTYPE_t[:,:] ax
cdef DTYPE_t[:,:,:] ag
ax = x
ag = g
for i in range(x.shape[0]):
do_not_put = 0
for j in range(3):
a[j] = (x[i,j]+half_Box)*delta_Box
a[j] = (ax[i,j]+half_Box)*delta_Box
b[j] = int(floor(a[j]))
b1[j] = (b[j]+1) % Ngrid
@ -390,18 +472,15 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
b[j] %= Ngrid
assert b[j] >= 0 and b[j] < Ngrid
assert b1[j] >= 0 and b1[j] < Ngrid
ag[b[0],b[1],b[2]] += c[0]*c[1]*c[2]
ag[b1[0],b[1],b[2]] += a[0]*c[1]*c[2]
ag[b[0],b1[1],b[2]] += c[0]*a[1]*c[2]
ag[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2]
g[b[0],b[1],b[2]] += c[0]*c[1]*c[2]
g[b1[0],b[1],b[2]] += a[0]*c[1]*c[2]
g[b[0],b1[1],b[2]] += c[0]*a[1]*c[2]
g[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2]
g[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
g[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
g[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
g[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
ag[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
ag[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
ag[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
ag[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
@cython.boundscheck(False)

View File

@ -1,7 +1,12 @@
from _cosmotool import *
from _project import *
from grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase
from borg import read_borg_vol
from cic import cicParticles
from simu import loadRamsesAll, simpleWriteGadget, SimulationBare
from timing import time_block, timeit, timeit_quiet
from .grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase
from .borg import read_borg_vol
from .cic import cicParticles
from .simu import loadRamsesAll, simpleWriteGadget, SimulationBare
from .timing import time_block, timeit, timeit_quiet
try:
from .fftw import CubeFT
except ImportError:
print("No FFTW support")

36
python/cosmotool/fftw.py Normal file
View File

@ -0,0 +1,36 @@
import pyfftw
import multiprocessing
import numpy as np
import numexpr as ne
class CubeFT(object):
def __init__(self, L, N, max_cpu=-1):
self.N = N
self.align = pyfftw.simd_alignment
self.L = L
self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64')
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32')
self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False)
self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False)
def rfft(self):
return ne.evaluate('c*a', out=self._dhat, local_dict={'c':self._rfft(normalise_idft=False),'a':(self.L/self.N)**3}, casting='unsafe')
def irfft(self):
return ne.evaluate('c*a', out=self._density, local_dict={'c':self._irfft(normalise_idft=False),'a':(1/self.L)**3}, casting='unsafe')
def get_dhat(self):
return self._dhat
def set_dhat(self, in_dhat):
self._dhat[:] = in_dhat
dhat = property(get_dhat, set_dhat, None)
def get_density(self):
return self._density
def set_density(self, d):
self._density[:] = d
density = property(get_density, set_density, None)

View File

@ -103,7 +103,7 @@ def run_generation(input_borg, a_borg, a_ic, cosmo, supersample=1, do_lpt2=True,
@ct.timeit_quiet
def whitify(density, L, cosmo, supergenerate=1, func='HU_WIGGLES'):
def whitify(density, L, cosmo, supergenerate=1, zero_fill=False, func='HU_WIGGLES'):
N = density.shape[0]
p = ct.CosmologyPower(**cosmo)
@ -141,6 +141,10 @@ def whitify(density, L, cosmo, supergenerate=1, func='HU_WIGGLES'):
if supergenerate > 1:
cond=np.isnan(density_hat_super)
if zero_fill:
density_hat_super[cond] = 0
else:
print np.where(np.isnan(density_hat_super))[0].size
Nz = np.count_nonzero(cond)
density_hat_super.real[cond] = np.random.randn(Nz)
@ -166,8 +170,8 @@ def write_icfiles(*generated_ic, **kwargs):
"""Write the initial conditions from the tuple returned by run_generation"""
supergenerate=1
if 'supergenerate' in kwargs:
supergenerate=kwargs['supergenerate']
supergenerate=kwargs.get('supergenerate', 1)
zero_fill=kwargs.get('zero_fill', False)
posx,vel,density,N,L,a_ic,cosmo = generated_ic
ct.simpleWriteGadget("Data/borg.gad", posx, velocities=vel, boxsize=L, Hubble=cosmo['h'], Omega_M=cosmo['omega_M_0'], time=a_ic)
@ -176,7 +180,7 @@ def write_icfiles(*generated_ic, **kwargs):
ct.writeGrafic("Data/ic_deltab", density, L, a_ic, **cosmo)
ct.writeWhitePhase("Data/white.dat", whitify(density, L, cosmo, supergenerate=supergenerate))
ct.writeWhitePhase("Data/white.dat", whitify(density, L, cosmo, supergenerate=supergenerate,zero_fill=zero_fill))
with file("Data/white_params", mode="w") as f:
f.write("4\n%lg, %lg, %lg\n" % (cosmo['omega_M_0'], cosmo['omega_lambda_0'], 100*cosmo['h']))

View File

@ -9,10 +9,11 @@ cosmo['omega_B_0']=0.049
cosmo['SIGMA8']=0.8344
cosmo['ns']=0.9624
supergen=1
supergen=2
zstart=50
astart=1/(1.+zstart)
astart=1/69.#1/(1.+zstart)
halfPixelShift=False
zero_fill=True
if __name__=="__main__":
bic.write_icfiles(*bic.run_generation("initial_density_988.dat", 0.001, astart, cosmo, supersample=1, shiftPixel=halfPixelShift, do_lpt2=False), supergenerate=supergen)
bic.write_icfiles(*bic.run_generation("initial_density_1380.dat", 0.001, astart, cosmo, supersample=1, shiftPixel=halfPixelShift, do_lpt2=False), supergenerate=supergen, zero_fill=zero_fill)