Support for multiple transforms and spin-transforms

This commit is contained in:
Dag Sverre Seljebotn 2015-05-19 20:53:52 +02:00
parent 48e213151a
commit 16888967c1
2 changed files with 31 additions and 17 deletions

View file

@ -54,16 +54,25 @@ JOBTYPE_TO_CONST = {
'YtW': SHARP_YtW
}
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
int spin=0, comm=None, add=False):
cdef void *comm_ptr
cdef int flags = SHARP_DP | (SHARP_ADD if add else 0)
cdef double *palm
cdef double *pmap
cdef int r
cdef sharp_jobtype jobtype_i
cdef double[::1] output_buf
cdef double[:, :, ::1] output_buf
cdef int ntrans = input.shape[0] * input.shape[1]
cdef int i, j
if spin == 0 and input.shape[1] != 1:
raise ValueError('For spin == 0, we need input.shape[1] == 1')
elif spin != 0 and input.shape[1] != 2:
raise ValueError('For spin != 0, we need input.shape[1] == 2')
cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
cdef double **alm_ptrs = <double**>&ptrbuf[0]
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
try:
jobtype_i = JOBTYPE_TO_CONST[jobtype]
@ -71,23 +80,27 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys())))
if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
output = np.empty(ginfo.local_size(), dtype=np.float64)
output = np.empty((input.shape[0], input.shape[1], ginfo.local_size()), dtype=np.float64)
output_buf = output
pmap = &output_buf[0]
palm = &input[0]
for i in range(input.shape[0]):
for j in range(input.shape[1]):
alm_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
map_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
else:
output = np.empty(ainfo.local_size(), dtype=np.float64)
output = np.empty((input.shape[0], input.shape[1], ainfo.local_size()), dtype=np.float64)
output_buf = output
pmap = &input[0]
palm = &output_buf[0]
for i in range(input.shape[0]):
for j in range(input.shape[1]):
alm_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
map_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
if comm is None:
with nogil:
sharp_execute (
jobtype_i,
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
spin=spin, alm=&palm, map=&pmap,
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
spin=spin, alm=alm_ptrs, map=map_ptrs,
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
else:
from mpi4py import MPI
if not isinstance(comm, MPI.Comm):
@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
r = sharp_execute_mpi_maybe (
comm_ptr, jobtype_i,
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
spin=spin, alm=&palm, map=&pmap,
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
spin=spin, alm=alm_ptrs, map=map_ptrs,
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
if r == SHARP_ERROR_NO_MPI:
raise Exception('MPI requested, but not available')

View file

@ -25,8 +25,9 @@ def test_basic():
alm[0] = 1
map = libsharp.synthesis(grid, order, alm, comm=MPI.COMM_WORLD)
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
map = map[0, 0, :]
if rank == 0:
healpy.mollzoom(map)
from matplotlib.pyplot import show