From 6150597c67a94d2749f8db1e115db00fb1f836bc Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Thu, 10 Jul 2014 16:19:40 +0200 Subject: [PATCH] OpenMP of _project --- python/_project.pyx | 96 +++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/python/_project.pyx b/python/_project.pyx index 4f69953..2de1834 100644 --- a/python/_project.pyx +++ b/python/_project.pyx @@ -21,9 +21,9 @@ cdef extern from "project_tool.hpp" namespace "": @cython.boundscheck(False) @cython.cdivision(True) -cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, +cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, DTYPE_t z, - DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0: + DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil: cdef int Ngrid = d.shape[0] cdef DTYPE_t inv_delta = Ngrid/Lbox @@ -32,9 +32,9 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, cdef DTYPE_t rx, ry, rz cdef int jx, jy, jz - rx = (inv_delta*x + Ngrid/2) - ry = (inv_delta*y + Ngrid/2) - rz = (inv_delta*z + Ngrid/2) + rx = (inv_delta*x) + ry = (inv_delta*y) + rz = (inv_delta*z) ix = int(floor(rx)) iy = int(floor(ry)) @@ -61,14 +61,11 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, iz = iz%Ngrid if (ix < 0) or (jx >= Ngrid): - with gil: - assert ((ix >= 0) and ((jx) < Ngrid)) + return -1 if (iy < 0) or (jy >= Ngrid): - with gil: - assert ((iy >= 0) and ((jy) < Ngrid)) + return -2 if (iz < 0) or (jz >= Ngrid): - with gil: - assert ((iz >= 0) and ((jz) < Ngrid)) + return -3 f[0][0][0] = (1-rx)*(1-ry)*(1-rz) f[1][0][0] = ( rx)*(1-ry)*(1-rz) @@ -80,7 +77,7 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, f[0][1][1] = (1-rx)*( ry)*( rz) f[1][1][1] = ( rx)*( ry)*( rz) - return \ + retval[0] = \ d[ix ,iy ,iz ] * f[0][0][0] + \ d[jx ,iy ,iz ] * f[1][0][0] + \ d[ix ,jy ,iz ] * f[0][1][0] + \ @@ -90,12 +87,14 @@ cdef DTYPE_t interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, d[ix ,jy ,jz ] * f[0][1][1] + \ d[jx ,jy ,jz ] * f[1][1][1] + return 0 + @cython.boundscheck(False) @cython.cdivision(True) -cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, +cdef int interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, DTYPE_t z, - DTYPE_t[:,:,:] d, DTYPE_t Lbox) nogil except? 0: + DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil: cdef int Ngrid = d.shape[0] cdef DTYPE_t inv_delta = Ngrid/Lbox @@ -103,9 +102,9 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, cdef DTYPE_t f[2][2][2] cdef DTYPE_t rx, ry, rz - rx = (inv_delta*x + Ngrid/2) - ry = (inv_delta*y + Ngrid/2) - rz = (inv_delta*z + Ngrid/2) + rx = (inv_delta*x) + ry = (inv_delta*y) + rz = (inv_delta*z) ix = int(floor(rx)) iy = int(floor(ry)) @@ -116,14 +115,13 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, rz -= iz if ((ix < 0) or (ix+1) >= Ngrid): - with gil: - raise IndexError("X coord out of bound (ix=%d, x=%g)" % (ix,x)) + return -1 + if ((iy < 0) or (iy+1) >= Ngrid): - with gil: - raise IndexError("Y coord out of bound (iy=%d, y=%g)" % (iy,y)) + return -2 + if ((iz < 0) or (iz+1) >= Ngrid): - with gil: - raise IndexError("Z coord out of bound (iz=%d, z=%g)" % (iz,z)) + return -3 # assert ((ix >= 0) and ((ix+1) < Ngrid)) # assert ((iy >= 0) and ((iy+1) < Ngrid)) # assert ((iz >= 0) and ((iz+1) < Ngrid)) @@ -138,7 +136,7 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, f[0][1][1] = (1-rx)*( ry)*( rz) f[1][1][1] = ( rx)*( ry)*( rz) - return \ + retval[0] = \ d[ix ,iy ,iz ] * f[0][0][0] + \ d[ix+1,iy ,iz ] * f[1][0][0] + \ d[ix ,iy+1,iz ] * f[0][1][0] + \ @@ -148,11 +146,13 @@ cdef DTYPE_t interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, d[ix ,iy+1,iz+1] * f[0][1][1] + \ d[ix+1,iy+1,iz+1] * f[1][1][1] + return 0 + @cython.boundscheck(False) def interp3d(x not None, y not None, z not None, npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox, - bool periodic=False): + bool periodic=False, bool centered=True): """ interp3d(x,y,z,d,Lbox,periodic=False) -> interpolated values Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic @@ -160,15 +160,25 @@ def interp3d(x not None, y not None, cdef npx.ndarray[DTYPE_t] out cdef DTYPE_t[:] out_slice cdef DTYPE_t[:] ax, ay, az + cdef DTYPE_t[:,:,:] in_slice + cdef DTYPE_t retval cdef long i cdef long Nelt cdef int myperiodic + cdef DTYPE_t shifter myperiodic = periodic + if centered: + shifter = Lbox/2 + else: + shifter = 0 + if d.shape[0] != d.shape[1] or d.shape[0] != d.shape[2]: raise ValueError("Grid must have a cubic shape") + + ierror = IndexError("Interpolating outside range") if type(x) == np.ndarray or type(y) == np.ndarray or type(z) == np.ndarray: if type(x) != np.ndarray or type(y) != np.ndarray or type(z) != np.ndarray: raise ValueError("All or no array. No partial arguments") @@ -180,22 +190,29 @@ def interp3d(x not None, y not None, out = np.empty(x.shape, dtype=DTYPE) out_slice = out + in_slice = d Nelt = ax.size with nogil: if myperiodic: - for i in prange(Nelt): - out[i] = interp3d_INTERNAL_periodic(ax[i], ay[i], az[i], d, Lbox) + for i in xrange(Nelt): + if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror else: - for i in prange(Nelt): - out[i] = interp3d_INTERNAL(ax[i], ay[i], az[i], d, Lbox) + for i in xrange(Nelt): + if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror return out else: if periodic: - return interp3d_INTERNAL_periodic(x, y, z, d, Lbox) + if interp3d_INTERNAL_periodic(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: + raise ierror else: - return interp3d_INTERNAL(x, y, z, d, Lbox) - + if interp3d_INTERNAL(shifter+x, shifter+y, shifter+z, d, Lbox, &retval) < 0: + raise ierror + return retval @cython.boundscheck(False) @cython.cdivision(True) cdef DTYPE_t interp2d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, @@ -326,7 +343,7 @@ cdef void INTERNAL_project_cic_no_mass(npx.ndarray[DTYPE_t, ndim=3] g, for i in range(x.shape[0]): - do_not_put = False + do_not_put = 0 for j in range(3): a[j] = (x[i,j]+half_Box)*delta_Box b[j] = int(floor(a[j])) @@ -362,19 +379,20 @@ cdef void INTERNAL_project_cic_no_mass_periodic(npx.ndarray[DTYPE_t, ndim=3] g, for i in range(x.shape[0]): - do_not_put = False + do_not_put = 0 for j in range(3): a[j] = (x[i,j]+half_Box)*delta_Box b[j] = int(floor(a[j])) - b1[j] = b[j]+1 - while b1[j] < 0: - b1[j] += Ngrid - while b1[j] >= Ngrid: - b1[j] -= Ngrid + b1[j] = (b[j]+1) % Ngrid a[j] -= b[j] c[j] = 1-a[j] + b[j] %= Ngrid + + assert b[j] >= 0 and b[j] < Ngrid + assert b1[j] >= 0 and b1[j] < Ngrid + g[b[0],b[1],b[2]] += c[0]*c[1]*c[2] g[b1[0],b[1],b[2]] += a[0]*c[1]*c[2] g[b[0],b1[1],b[2]] += c[0]*a[1]*c[2]