From 16888967c105848f484df10c77678a0c51b371d6 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Tue, 19 May 2015 20:53:52 +0200 Subject: [PATCH] Support for multiple transforms and spin-transforms --- python/libsharp/libsharp.pyx | 43 ++++++++++++++++++++----------- python/libsharp/tests/test_sht.py | 5 ++-- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/python/libsharp/libsharp.pyx b/python/libsharp/libsharp.pyx index 3144e03..ddfcd98 100644 --- a/python/libsharp/libsharp.pyx +++ b/python/libsharp/libsharp.pyx @@ -54,16 +54,25 @@ JOBTYPE_TO_CONST = { 'YtW': SHARP_YtW } - -def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input, +def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input, int spin=0, comm=None, add=False): cdef void *comm_ptr cdef int flags = SHARP_DP | (SHARP_ADD if add else 0) - cdef double *palm - cdef double *pmap cdef int r cdef sharp_jobtype jobtype_i - cdef double[::1] output_buf + cdef double[:, :, ::1] output_buf + cdef int ntrans = input.shape[0] * input.shape[1] + cdef int i, j + + if spin == 0 and input.shape[1] != 1: + raise ValueError('For spin == 0, we need input.shape[1] == 1') + elif spin != 0 and input.shape[1] != 2: + raise ValueError('For spin != 0, we need input.shape[1] == 2') + + + cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp) + cdef double **alm_ptrs = &ptrbuf[0] + cdef double **map_ptrs = &ptrbuf[ntrans] try: jobtype_i = JOBTYPE_TO_CONST[jobtype] @@ -71,23 +80,27 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input, raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys()))) if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY: - output = np.empty(ginfo.local_size(), dtype=np.float64) + output = np.empty((input.shape[0], input.shape[1], ginfo.local_size()), dtype=np.float64) output_buf = output - pmap = &output_buf[0] - palm = &input[0] + for i in range(input.shape[0]): + for j in range(input.shape[1]): + alm_ptrs[i * input.shape[1] + j] = &input[i, j, 0] + map_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0] else: - output = np.empty(ainfo.local_size(), dtype=np.float64) + output = np.empty((input.shape[0], input.shape[1], ainfo.local_size()), dtype=np.float64) output_buf = output - pmap = &input[0] - palm = &output_buf[0] + for i in range(input.shape[0]): + for j in range(input.shape[1]): + alm_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0] + map_ptrs[i * input.shape[1] + j] = &input[i, j, 0] if comm is None: with nogil: sharp_execute ( jobtype_i, geom_info=ginfo.ginfo, alm_info=ainfo.ainfo, - spin=spin, alm=&palm, map=&pmap, - ntrans=1, flags=flags, time=NULL, opcnt=NULL) + spin=spin, alm=alm_ptrs, map=map_ptrs, + ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL) else: from mpi4py import MPI if not isinstance(comm, MPI.Comm): @@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input, r = sharp_execute_mpi_maybe ( comm_ptr, jobtype_i, geom_info=ginfo.ginfo, alm_info=ainfo.ainfo, - spin=spin, alm=&palm, map=&pmap, - ntrans=1, flags=flags, time=NULL, opcnt=NULL) + spin=spin, alm=alm_ptrs, map=map_ptrs, + ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL) if r == SHARP_ERROR_NO_MPI: raise Exception('MPI requested, but not available') diff --git a/python/libsharp/tests/test_sht.py b/python/libsharp/tests/test_sht.py index e97b90f..459f446 100644 --- a/python/libsharp/tests/test_sht.py +++ b/python/libsharp/tests/test_sht.py @@ -25,8 +25,9 @@ def test_basic(): alm[0] = 1 - - map = libsharp.synthesis(grid, order, alm, comm=MPI.COMM_WORLD) + map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD) + assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :]) + map = map[0, 0, :] if rank == 0: healpy.mollzoom(map) from matplotlib.pyplot import show