Parallelized interp3d

This commit is contained in:
Guilhem Lavaux 2014-07-09 13:58:48 +02:00
parent 8809d6c255
commit 0001d86977

View File

@ -23,7 +23,7 @@ cdef extern from "project_tool.hpp" namespace "":
@cython.cdivision(True) @cython.cdivision(True)
cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
DTYPE_t z, DTYPE_t z,
npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0: DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0:
cdef int Ngrid = d.shape[0] cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox cdef DTYPE_t inv_delta = Ngrid/Lbox
@ -60,8 +60,14 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
iy = iy%Ngrid iy = iy%Ngrid
iz = iz%Ngrid iz = iz%Ngrid
if (ix < 0) or (jx >= Ngrid):
with gil:
assert ((ix >= 0) and ((jx) < Ngrid)) assert ((ix >= 0) and ((jx) < Ngrid))
if (iy < 0) or (jy >= Ngrid):
with gil:
assert ((iy >= 0) and ((jy) < Ngrid)) assert ((iy >= 0) and ((jy) < Ngrid))
if (iz < 0) or (jz >= Ngrid):
with gil:
assert ((iz >= 0) and ((jz) < Ngrid)) assert ((iz >= 0) and ((jz) < Ngrid))
f[0][0][0] = (1-rx)*(1-ry)*(1-rz) f[0][0][0] = (1-rx)*(1-ry)*(1-rz)
@ -89,7 +95,7 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
@cython.cdivision(True) @cython.cdivision(True)
cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
DTYPE_t z, DTYPE_t z,
npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0: DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0:
cdef int Ngrid = d.shape[0] cdef int Ngrid = d.shape[0]
cdef DTYPE_t inv_delta = Ngrid/Lbox cdef DTYPE_t inv_delta = Ngrid/Lbox
@ -110,10 +116,13 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
rz -= iz rz -= iz
if ((ix < 0) or (ix+1) >= Ngrid): if ((ix < 0) or (ix+1) >= Ngrid):
with gil:
raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x)) raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x))
if ((iy < 0) or (iy+1) >= Ngrid): if ((iy < 0) or (iy+1) >= Ngrid):
with gil:
raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y)) raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y))
if ((iz < 0) or (iz+1) >= Ngrid): if ((iz < 0) or (iz+1) >= Ngrid):
with gil:
raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z)) raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z))
# assert ((ix >= 0) and ((ix+1) < Ngrid)) # assert ((ix >= 0) and ((ix+1) < Ngrid))
# assert ((iy >= 0) and ((iy+1) < Ngrid)) # assert ((iy >= 0) and ((iy+1) < Ngrid))
@ -139,6 +148,7 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
d[ix ,iy+1,iz+1] * f[0][1][1] + \ d[ix ,iy+1,iz+1] * f[0][1][1] + \
d[ix+1,iy+1,iz+1] * f[1][1][1] d[ix+1,iy+1,iz+1] * f[1][1][1]
@cython.boundscheck(False)
def interp3d(x not None, y not None, def interp3d(x not None, y not None,
z not None, z not None,
npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox, npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox,
@ -148,8 +158,13 @@ def interp3d(x not None, y not None,
Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic
""" """
cdef npx.ndarray[DTYPE_t] out cdef npx.ndarray[DTYPE_t] out
cdef npx.ndarray[DTYPE_t] ax, ay, az cdef DTYPE_t[:] out_slice
cdef int i cdef DTYPE_t[:] ax, ay, az
cdef long i
cdef long Nelt
cdef int myperiodic
myperiodic = periodic
if d.shape[0] != d.shape[1] or d.shape[0] != d.shape[2]: if d.shape[0] != d.shape[1] or d.shape[0] != d.shape[2]:
raise ValueError("Grid must have a cubic shape") raise ValueError("Grid must have a cubic shape")
@ -164,11 +179,14 @@ def interp3d(x not None, y not None,
assert ax.size == ay.size and ax.size == az.size assert ax.size == ay.size and ax.size == az.size
out = np.empty(x.shape, dtype=DTYPE) out = np.empty(x.shape, dtype=DTYPE)
if periodic: out_slice = out
for i in range(ax.size): Nelt = ax.size
with nogil:
if myperiodic:
for i in prange(Nelt):
out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox) out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox)
else: else:
for i in range(ax.size): for i in prange(Nelt):
out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox) out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox)
return out return out