From 0001d86977e3fd46a5ba37ff71d68ea661e5d955 Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Wed, 9 Jul 2014 13:58:48 +0200 Subject: [PATCH] Parallelized interp3d --- python/_project.pyx | 50 ++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/python/_project.pyx b/python/_project.pyx index a8ff9dc..4f69953 100644 --- a/python/_project.pyx +++ b/python/_project.pyx @@ -23,7 +23,7 @@ cdef extern from "project_tool.hpp" namespace "": @cython.cdivision(True) cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, DTYPE_t z, - npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0: + DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0: cdef int Ngrid = d.shape[0] cdef DTYPE_t inv_delta = Ngrid/Lbox @@ -60,9 +60,15 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, iy = iy%Ngrid iz = iz%Ngrid - assert ((ix >= 0) and ((jx) < Ngrid)) - assert ((iy >= 0) and ((jy) < Ngrid)) - assert ((iz >= 0) and ((jz) < Ngrid)) + if (ix < 0) or (jx >= Ngrid): + with gil: + assert ((ix >= 0) and ((jx) < Ngrid)) + if (iy < 0) or (jy >= Ngrid): + with gil: + assert ((iy >= 0) and ((jy) < Ngrid)) + if (iz < 0) or (jz >= Ngrid): + with gil: + assert ((iz >= 0) and ((jz) < Ngrid)) f[0][0][0] = (1-rx)*(1-ry)*(1-rz) f[1][0][0] = ( rx)*(1-ry)*(1-rz) @@ -89,7 +95,7 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, @cython.cdivision(True) cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, DTYPE_t z, - npx.ndarray[DTYPE_t, ndim=3] d, DTYPE_t Lbox) except? 0: + DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0: cdef int Ngrid = d.shape[0] cdef DTYPE_t inv_delta = Ngrid/Lbox @@ -110,11 +116,14 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, rz -= iz if ((ix < 0) or (ix+1) >= Ngrid): - raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x)) + with gil: + raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x)) if ((iy < 0) or (iy+1) >= Ngrid): - raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y)) + with gil: + raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y)) if ((iz < 0) or (iz+1) >= Ngrid): - raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z)) + with gil: + raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z)) # assert ((ix >= 0) and ((ix+1) < Ngrid)) # assert ((iy >= 0) and ((iy+1) < Ngrid)) # assert ((iz >= 0) and ((iz+1) < Ngrid)) @@ -139,6 +148,7 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, d[ix ,iy+1,iz+1] * f[0][1][1] + \ d[ix+1,iy+1,iz+1] * f[1][1][1] +@cython.boundscheck(False) def interp3d(x not None, y not None, z not None, npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox, @@ -148,8 +158,13 @@ def interp3d(x not None, y not None, Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic """ cdef npx.ndarray[DTYPE_t] out - cdef npx.ndarray[DTYPE_t] ax, ay, az - cdef int i + cdef DTYPE_t[:] out_slice + cdef DTYPE_t[:] ax, ay, az + cdef long i + cdef long Nelt + cdef int myperiodic + + myperiodic = periodic if d.shape[0] != d.shape[1] or d.shape[0] != d.shape[2]: raise ValueError("Grid must have a cubic shape") @@ -164,12 +179,15 @@ def interp3d(x not None, y not None, assert ax.size == ay.size and ax.size == az.size out = np.empty(x.shape, dtype=DTYPE) - if periodic: - for i in range(ax.size): - out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox) - else: - for i in range(ax.size): - out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox) + out_slice = out + Nelt = ax.size + with nogil: + if myperiodic: + for i in prange(Nelt): + out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox) + else: + for i in prange(Nelt): + out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox) return out else: