Parallelized interp3d
This commit is contained in:
parent
8809d6c255
commit
0001d86977
1 changed files with 34 additions and 16 deletions
|
@ -23,7 +23,7 @@ cdef extern from "project_tool.hpp" namespace "":
|
|||
@cython.cdivision(True)
|
||||
cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
|
||||
DTYPE_t z,
|
||||
npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0:
|
||||
DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0:
|
||||
|
||||
cdef int Ngrid = d.shape[0]
|
||||
cdef DTYPE_t inv_delta = Ngrid/Lbox
|
||||
|
@ -60,9 +60,15 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
|
|||
iy = iy%Ngrid
|
||||
iz = iz%Ngrid
|
||||
|
||||
assert ((ix >= 0) and ((jx) < Ngrid))
|
||||
assert ((iy >= 0) and ((jy) < Ngrid))
|
||||
assert ((iz >= 0) and ((jz) < Ngrid))
|
||||
if (ix < 0) or (jx >= Ngrid):
|
||||
with gil:
|
||||
assert ((ix >= 0) and ((jx) < Ngrid))
|
||||
if (iy < 0) or (jy >= Ngrid):
|
||||
with gil:
|
||||
assert ((iy >= 0) and ((jy) < Ngrid))
|
||||
if (iz < 0) or (jz >= Ngrid):
|
||||
with gil:
|
||||
assert ((iz >= 0) and ((jz) < Ngrid))
|
||||
|
||||
f[0][0][0] = (1-rx)*(1-ry)*(1-rz)
|
||||
f[1][0][0] = ( rx)*(1-ry)*(1-rz)
|
||||
|
@ -89,7 +95,7 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y,
|
|||
@cython.cdivision(True)
|
||||
cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
|
||||
DTYPE_t z,
|
||||
npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0:
|
||||
DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0:
|
||||
|
||||
cdef int Ngrid = d.shape[0]
|
||||
cdef DTYPE_t inv_delta = Ngrid/Lbox
|
||||
|
@ -110,11 +116,14 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
|
|||
rz -= iz
|
||||
|
||||
if ((ix < 0) or (ix+1) >= Ngrid):
|
||||
raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x))
|
||||
with gil:
|
||||
raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x))
|
||||
if ((iy < 0) or (iy+1) >= Ngrid):
|
||||
raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y))
|
||||
with gil:
|
||||
raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y))
|
||||
if ((iz < 0) or (iz+1) >= Ngrid):
|
||||
raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z))
|
||||
with gil:
|
||||
raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z))
|
||||
# assert ((ix >= 0) and ((ix+1) < Ngrid))
|
||||
# assert ((iy >= 0) and ((iy+1) < Ngrid))
|
||||
# assert ((iz >= 0) and ((iz+1) < Ngrid))
|
||||
|
@ -139,6 +148,7 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y,
|
|||
d[ix ,iy+1,iz+1] * f[0][1][1] + \
|
||||
d[ix+1,iy+1,iz+1] * f[1][1][1]
|
||||
|
||||
@cython.boundscheck(False)
|
||||
def interp3d(x not None, y not None,
|
||||
z not None,
|
||||
npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox,
|
||||
|
@ -148,8 +158,13 @@ def interp3d(x not None, y not None,
|
|||
Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic
|
||||
"""
|
||||
cdef npx.ndarray[DTYPE_t] out
|
||||
cdef npx.ndarray[DTYPE_t] ax, ay, az
|
||||
cdef int i
|
||||
cdef DTYPE_t[:] out_slice
|
||||
cdef DTYPE_t[:] ax, ay, az
|
||||
cdef long i
|
||||
cdef long Nelt
|
||||
cdef int myperiodic
|
||||
|
||||
myperiodic = periodic
|
||||
|
||||
if d.shape[0] != d.shape[1] or d.shape[0] != d.shape[2]:
|
||||
raise ValueError("Grid must have a cubic shape")
|
||||
|
@ -164,12 +179,15 @@ def interp3d(x not None, y not None,
|
|||
assert ax.size == ay.size and ax.size == az.size
|
||||
|
||||
out = np.empty(x.shape, dtype=DTYPE)
|
||||
if periodic:
|
||||
for i in range(ax.size):
|
||||
out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox)
|
||||
else:
|
||||
for i in range(ax.size):
|
||||
out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox)
|
||||
out_slice = out
|
||||
Nelt = ax.size
|
||||
with nogil:
|
||||
if myperiodic:
|
||||
for i in prange(Nelt):
|
||||
out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox)
|
||||
else:
|
||||
for i in prange(Nelt):
|
||||
out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox)
|
||||
|
||||
return out
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue