diff --git a/python/_project.pyx b/python/_project.pyx index 317f001..0906474 100644 --- a/python/_project.pyx +++ b/python/_project.pyx @@ -82,6 +82,76 @@ cdef int interp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, return 0 +@cython.boundscheck(False) +@cython.cdivision(True) +cdef int ngp3d_INTERNAL_periodic(DTYPE_t x, DTYPE_t y, + DTYPE_t z, + DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil: + + cdef int Ngrid = d.shape[0] + cdef DTYPE_t inv_delta = Ngrid/Lbox + cdef int ix, iy, iz + cdef DTYPE_t f[2][2][2] + cdef DTYPE_t rx, ry, rz + cdef int jx, jy, jz + + rx = (inv_delta*x) + ry = (inv_delta*y) + rz = (inv_delta*z) + + ix = int(floor(rx)) + iy = int(floor(ry)) + iz = int(floor(rz)) + + + ix = ix%Ngrid + iy = iy%Ngrid + iz = iz%Ngrid + + if (ix < 0) or (jx >= Ngrid): + return -1 + if (iy < 0) or (jy >= Ngrid): + return -2 + if (iz < 0) or (jz >= Ngrid): + return -3 + + retval[0] = d[ix ,iy ,iz ] + + return 0 + + +@cython.boundscheck(False) +@cython.cdivision(True) +cdef int ngp3d_INTERNAL(DTYPE_t x, DTYPE_t y, + DTYPE_t z, + DTYPE_t[:,:,:] d, DTYPE_t Lbox, DTYPE_t *retval) nogil: + + cdef int Ngrid = d.shape[0] + cdef DTYPE_t inv_delta = Ngrid/Lbox + cdef int ix, iy, iz + cdef DTYPE_t f[2][2][2] + cdef DTYPE_t rx, ry, rz + cdef int jx, jy, jz + + rx = (inv_delta*x) + ry = (inv_delta*y) + rz = (inv_delta*z) + + ix = int(floor(rx)) + iy = int(floor(ry)) + iz = int(floor(rz)) + + if (ix < 0) or (jx >= Ngrid): + return -1 + if (iy < 0) or (jy >= Ngrid): + return -2 + if (iz < 0) or (jz >= Ngrid): + return -3 + + retval[0] = d[ix ,iy ,iz ] + + return 0 + @cython.boundscheck(False) @cython.cdivision(True) @@ -145,8 +215,8 @@ cdef int interp3d_INTERNAL(DTYPE_t x, DTYPE_t y, def interp3d(x not None, y not None, z not None, npx.ndarray[DTYPE_t, ndim=3] d not None, DTYPE_t Lbox, - bool periodic=False, bool centered=True): - """ interp3d(x,y,z,d,Lbox,periodic=False) -> interpolated values + bool periodic=False, bool centered=True, bool ngp=False): + """ interp3d(x,y,z,d,Lbox,periodic=False,centered=True,ngp=False) -> interpolated values Compute the tri-linear interpolation of the given field (d) at the given position (x,y,z). It assumes that they are box-centered coordinates. So (x,y,z) == (0,0,0) is equivalent to the pixel at (Nx/2,Ny/2,Nz/2) with Nx,Ny,Nz = d.shape. If periodic is set, it assumes the box is periodic """ @@ -157,10 +227,11 @@ def interp3d(x not None, y not None, cdef DTYPE_t retval cdef long i cdef long Nelt - cdef int myperiodic + cdef int myperiodic, myngp cdef DTYPE_t shifter myperiodic = periodic + myngp = ngp if centered: shifter = Lbox/2 @@ -186,17 +257,28 @@ def interp3d(x not None, y not None, in_slice = d Nelt = ax.size with nogil: - if myperiodic: - for i in prange(Nelt): - if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: - with gil: - raise ierror + if myngp: + if myperiodic: + for i in prange(Nelt): + if interp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror + else: + for i in prange(Nelt): + if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror else: - for i in prange(Nelt): - if interp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: - with gil: - raise ierror - + if myperiodic: + for i in prange(Nelt): + if ngp3d_INTERNAL_periodic(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror + else: + for i in prange(Nelt): + if ngp3d_INTERNAL(shifter+ax[i], shifter+ay[i], shifter+az[i], in_slice, Lbox, &out_slice[i]) < 0: + with gil: + raise ierror return out else: if periodic: