This commit is contained in:
Guilhem Lavaux 2014-07-28 09:34:20 +02:00
commit 025cdba6a9
5 changed files with 185 additions and 60 deletions

View File

@ -1,7 +1,7 @@
from cpython cimport bool from cpython cimport bool
from cython cimport view from cython cimport view
from cython.parallel import prange, parallel from cython.parallel import prange, parallel
from libc.math cimport sin, cos, abs, floor, sqrt from libc.math cimport sin, cos, abs, floor, round, sqrt
import numpy as np import numpy as np
cimport numpy as npx cimport numpy as npx
cimport cython cimport cython
@ -45,13 +45,6 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
ry -= iy ry -= iy
rz -= iz rz -= iz
while ix < 0:
ix += Ngrid
while iy < 0:
iy += Ngrid
while iz < 0:
iz += Ngrid
jx = (ix+1)%Ngrid jx = (ix+1)%Ngrid
jy = (iy+1)%Ngrid jy = (iy+1)%Ngrid
jz = (iz+1)%Ngrid jz = (iz+1)%Ngrid
@ -89,6 +82,69 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
return 0 return 0
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int ngp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
DTYPE_t z,
DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil:
cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox
cdef int ix, iy, iz
cdef DTYPE_t f[2][2][2]
cdef DTYPE_t rx, ry, rz
cdef int jx, jy, jz
rx = (inv_delta*x)
ry = (inv_delta*y)
rz = (inv_delta*z)
ix = int(round(rx))
iy = int(round(ry))
iz = int(round(rz))
ix = ix%Ngrid
iy = iy%Ngrid
iz = iz%Ngrid
retval[0] = d[ix ,iy ,iz ]
return 0
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int ngp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
DTYPE_t z,
DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil:
cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox
cdef int ix, iy, iz
cdef DTYPE_t f[2][2][2]
cdef DTYPE_t rx, ry, rz
cdef int jx, jy, jz
rx = (inv_delta*x)
ry = (inv_delta*y)
rz = (inv_delta*z)
ix = int(round(rx))
iy = int(round(ry))
iz = int(round(rz))
if (ix < 0 or ix >= Ngrid):
return -1
if (iy < 0 or iy >= Ngrid):
return -2
if (iz < 0 or iz >= Ngrid):
return -3
retval[0] = d[ix ,iy ,iz ]
return 0
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.cdivision(True) @cython.cdivision(True)
@ -152,8 +208,8 @@ cdef int interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
def interp3d(x not None, y not None, def interp3d(x not None, y not None,
z not None, z not None,
npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox, npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox,
bool periodic=False, bool centered=True): bool periodic=False, bool centered=True, bool ngp=False):
""" interp3d(x,y,z,d,Lbox,periodic=False) -> interpolated values """ interp3d(x,y,z,d,Lbox,periodic=False,centered=True,ngp=False) -> interpolated values
Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic
""" """
@ -164,10 +220,11 @@ def interp3d(x not None, y not None,
cdef DTYPE_t retval cdef DTYPE_t retval
cdef long i cdef long i
cdef long Nelt cdef long Nelt
cdef int myperiodic cdef int myperiodic, myngp
cdef DTYPE_t shifter cdef DTYPE_t shifter
myperiodic = periodic myperiodic = periodic
myngp = ngp
if centered: if centered:
shifter = Lbox/2 shifter = Lbox/2
@ -193,26 +250,46 @@ def interp3d(x not None, y not None,
in_slice = d in_slice = d
Nelt = ax.size Nelt = ax.size
with nogil: with nogil:
if not myngp:
if myperiodic: if myperiodic:
for i in xrange(Nelt): for i in prange(Nelt):
if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil: with gil:
raise ierror raise ierror
else: else:
for i in xrange(Nelt): for i in prange(Nelt):
if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil: with gil:
raise ierror raise ierror
else:
if myperiodic:
for i in prange(Nelt):
if ngp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
else:
for i in prange(Nelt):
if ngp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0:
with gil:
raise ierror
return out return out
else: else:
if not myngp:
if periodic: if periodic:
if interp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: if interp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror raise ierror
else: else:
if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror raise ierror
else:
if periodic:
if ngp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
else:
if ngp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0:
raise ierror
return retval return retval
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.cdivision(True) @cython.cdivision(True)
cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
@ -376,12 +453,17 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
cdef double a[3], c[3] cdef double a[3], c[3]
cdef int b[3], b1[3] cdef int b[3], b1[3]
cdef int do_not_put cdef int do_not_put
cdef DTYPE_t[:,:] ax
cdef DTYPE_t[:,:,:] ag
ax = x
ag = g
for i in range(x.shape[0]): for i in range(x.shape[0]):
do_not_put = 0 do_not_put = 0
for j in range(3): for j in range(3):
a[j] = (x[i,j]+half_Box)*delta_Box a[j] = (ax[i,j]+half_Box)*delta_Box
b[j] = int(floor(a[j])) b[j] = int(floor(a[j]))
b1[j] = (b[j]+1) % Ngrid b1[j] = (b[j]+1) % Ngrid
@ -390,18 +472,15 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g,
b[j] %= Ngrid b[j] %= Ngrid
assert b[j] >= 0 and b[j] < Ngrid ag[b[0],b[1],b[2]] += c[0]*c[1]*c[2]
assert b1[j] >= 0 and b1[j] < Ngrid ag[b1[0],b[1],b[2]] += a[0]*c[1]*c[2]
ag[b[0],b1[1],b[2]] += c[0]*a[1]*c[2]
ag[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2]
g[b[0],b[1],b[2]] += c[0]*c[1]*c[2] ag[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
g[b1[0],b[1],b[2]] += a[0]*c[1]*c[2] ag[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
g[b[0],b1[1],b[2]] += c[0]*a[1]*c[2] ag[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
g[b1[0],b1[1],b[2]] += a[0]*a[1]*c[2] ag[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
g[b[0],b[1],b1[2]] += c[0]*c[1]*a[2]
g[b1[0],b[1],b1[2]] += a[0]*c[1]*a[2]
g[b[0],b1[1],b1[2]] += c[0]*a[1]*a[2]
g[b1[0],b1[1],b1[2]] += a[0]*a[1]*a[2]
@cython.boundscheck(False) @cython.boundscheck(False)

View File

@ -1,7 +1,12 @@
from _cosmotool import * from _cosmotool import *
from _project import * from _project import *
from grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase from .grafic import writeGrafic, writeWhitePhase, readGrafic, readWhitePhase
from borg import read_borg_vol from .borg import read_borg_vol
from cic import cicParticles from .cic import cicParticles
from simu import loadRamsesAll, simpleWriteGadget, SimulationBare from .simu import loadRamsesAll, simpleWriteGadget, SimulationBare
from timing import time_block, timeit, timeit_quiet from .timing import time_block, timeit, timeit_quiet
try:
from .fftw import CubeFT
except ImportError:
print("No FFTW support")

36
python/cosmotool/fftw.py Normal file
View File

@ -0,0 +1,36 @@
import pyfftw
import multiprocessing
import numpy as np
import numexpr as ne
class CubeFT(object):
def __init__(self, L, N, max_cpu=-1):
self.N = N
self.align = pyfftw.simd_alignment
self.L = L
self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64')
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32')
self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu, normalize_idft=False)
self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu, normalize_idft=False)
def rfft(self):
return ne.evaluate('c*a', out=self._dhat, local_dict={'c':self._rfft(normalise_idft=False),'a':(self.L/self.N)**3}, casting='unsafe')
def irfft(self):
return ne.evaluate('c*a', out=self._density, local_dict={'c':self._irfft(normalise_idft=False),'a':(1/self.L)**3}, casting='unsafe')
def get_dhat(self):
return self._dhat
def set_dhat(self, in_dhat):
self._dhat[:] = in_dhat
dhat = property(get_dhat, set_dhat, None)
def get_density(self):
return self._density
def set_density(self, d):
self._density[:] = d
density = property(get_density, set_density, None)

View File

@ -103,7 +103,7 @@ def run_generation(input_borg, a_borg, a_ic, cosmo, supersample=1, do_lpt2=True,
@ct.timeit_quiet @ct.timeit_quiet
def whitify(density, L, cosmo, supergenerate=1, func='HU_WIGGLES'): def whitify(density, L, cosmo, supergenerate=1, zero_fill=False, func='HU_WIGGLES'):
N = density.shape[0] N = density.shape[0]
p = ct.CosmologyPower(**cosmo) p = ct.CosmologyPower(**cosmo)
@ -141,6 +141,10 @@ def whitify(density, L, cosmo, supergenerate=1, func='HU_WIGGLES'):
if supergenerate > 1: if supergenerate > 1:
cond=np.isnan(density_hat_super) cond=np.isnan(density_hat_super)
if zero_fill:
density_hat_super[cond] = 0
else:
print np.where(np.isnan(density_hat_super))[0].size print np.where(np.isnan(density_hat_super))[0].size
Nz = np.count_nonzero(cond) Nz = np.count_nonzero(cond)
density_hat_super.real[cond] = np.random.randn(Nz) density_hat_super.real[cond] = np.random.randn(Nz)
@ -166,8 +170,8 @@ def write_icfiles(*generated_ic, **kwargs):
"""Write the initial conditions from the tuple returned by run_generation""" """Write the initial conditions from the tuple returned by run_generation"""
supergenerate=1 supergenerate=1
if 'supergenerate' in kwargs: supergenerate=kwargs.get('supergenerate', 1)
supergenerate=kwargs['supergenerate'] zero_fill=kwargs.get('zero_fill', False)
posx,vel,density,N,L,a_ic,cosmo = generated_ic posx,vel,density,N,L,a_ic,cosmo = generated_ic
ct.simpleWriteGadget("Data/borg.gad", posx, velocities=vel, boxsize=L, Hubble=cosmo['h'], Omega_M=cosmo['omega_M_0'], time=a_ic) ct.simpleWriteGadget("Data/borg.gad", posx, velocities=vel, boxsize=L, Hubble=cosmo['h'], Omega_M=cosmo['omega_M_0'], time=a_ic)
@ -176,7 +180,7 @@ def write_icfiles(*generated_ic, **kwargs):
ct.writeGrafic("Data/ic_deltab", density, L, a_ic, **cosmo) ct.writeGrafic("Data/ic_deltab", density, L, a_ic, **cosmo)
ct.writeWhitePhase("Data/white.dat", whitify(density, L, cosmo, supergenerate=supergenerate)) ct.writeWhitePhase("Data/white.dat", whitify(density, L, cosmo, supergenerate=supergenerate,zero_fill=zero_fill))
with file("Data/white_params", mode="w") as f: with file("Data/white_params", mode="w") as f:
f.write("4\n%lg, %lg, %lg\n" % (cosmo['omega_M_0'], cosmo['omega_lambda_0'], 100*cosmo['h'])) f.write("4\n%lg, %lg, %lg\n" % (cosmo['omega_M_0'], cosmo['omega_lambda_0'], 100*cosmo['h']))

View File

@ -9,10 +9,11 @@ cosmo['omega_B_0']=0.049
cosmo['SIGMA8']=0.8344 cosmo['SIGMA8']=0.8344
cosmo['ns']=0.9624 cosmo['ns']=0.9624
supergen=1 supergen=2
zstart=50 zstart=50
astart=1/(1.+zstart) astart=1/69.#1/(1.+zstart)
halfPixelShift=False halfPixelShift=False
zero_fill=True
if __name__=="__main__": if __name__=="__main__":
bic.write_icfiles(*bic.run_generation("initial_density_988.dat", 0.001, astart, cosmo, supersample=1, shiftPixel=halfPixelShift, do_lpt2=False), supergenerate=supergen) bic.write_icfiles(*bic.run_generation("initial_density_1380.dat", 0.001, astart, cosmo, supersample=1, shiftPixel=halfPixelShift, do_lpt2=False), supergenerate=supergen, zero_fill=zero_fill)