Support for multiple transforms and spin-transforms
This commit is contained in:
parent
48e213151a
commit
16888967c1
2 changed files with 31 additions and 17 deletions
|
@ -54,16 +54,25 @@ JOBTYPE_TO_CONST = {
|
|||
'YtW': SHARP_YtW
|
||||
}
|
||||
|
||||
|
||||
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
||||
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
|
||||
int spin=0, comm=None, add=False):
|
||||
cdef void *comm_ptr
|
||||
cdef int flags = SHARP_DP | (SHARP_ADD if add else 0)
|
||||
cdef double *palm
|
||||
cdef double *pmap
|
||||
cdef int r
|
||||
cdef sharp_jobtype jobtype_i
|
||||
cdef double[::1] output_buf
|
||||
cdef double[:, :, ::1] output_buf
|
||||
cdef int ntrans = input.shape[0] * input.shape[1]
|
||||
cdef int i, j
|
||||
|
||||
if spin == 0 and input.shape[1] != 1:
|
||||
raise ValueError('For spin == 0, we need input.shape[1] == 1')
|
||||
elif spin != 0 and input.shape[1] != 2:
|
||||
raise ValueError('For spin != 0, we need input.shape[1] == 2')
|
||||
|
||||
|
||||
cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
|
||||
cdef double **alm_ptrs = <double**>&ptrbuf[0]
|
||||
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
|
||||
|
||||
try:
|
||||
jobtype_i = JOBTYPE_TO_CONST[jobtype]
|
||||
|
@ -71,23 +80,27 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
|||
raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys())))
|
||||
|
||||
if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
|
||||
output = np.empty(ginfo.local_size(), dtype=np.float64)
|
||||
output = np.empty((input.shape[0], input.shape[1], ginfo.local_size()), dtype=np.float64)
|
||||
output_buf = output
|
||||
pmap = &output_buf[0]
|
||||
palm = &input[0]
|
||||
for i in range(input.shape[0]):
|
||||
for j in range(input.shape[1]):
|
||||
alm_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
|
||||
map_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
|
||||
else:
|
||||
output = np.empty(ainfo.local_size(), dtype=np.float64)
|
||||
output = np.empty((input.shape[0], input.shape[1], ainfo.local_size()), dtype=np.float64)
|
||||
output_buf = output
|
||||
pmap = &input[0]
|
||||
palm = &output_buf[0]
|
||||
for i in range(input.shape[0]):
|
||||
for j in range(input.shape[1]):
|
||||
alm_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
|
||||
map_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
|
||||
|
||||
if comm is None:
|
||||
with nogil:
|
||||
sharp_execute (
|
||||
jobtype_i,
|
||||
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
||||
spin=spin, alm=&palm, map=&pmap,
|
||||
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
|
||||
spin=spin, alm=alm_ptrs, map=map_ptrs,
|
||||
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
|
||||
else:
|
||||
from mpi4py import MPI
|
||||
if not isinstance(comm, MPI.Comm):
|
||||
|
@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
|||
r = sharp_execute_mpi_maybe (
|
||||
comm_ptr, jobtype_i,
|
||||
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
||||
spin=spin, alm=&palm, map=&pmap,
|
||||
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
|
||||
spin=spin, alm=alm_ptrs, map=map_ptrs,
|
||||
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
|
||||
if r == SHARP_ERROR_NO_MPI:
|
||||
raise Exception('MPI requested, but not available')
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ def test_basic():
|
|||
alm[0] = 1
|
||||
|
||||
|
||||
|
||||
map = libsharp.synthesis(grid, order, alm, comm=MPI.COMM_WORLD)
|
||||
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
|
||||
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
|
||||
map = map[0, 0, :]
|
||||
if rank == 0:
|
||||
healpy.mollzoom(map)
|
||||
from matplotlib.pyplot import show
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue