Support for multiple transforms and spin-transforms
This commit is contained in:
parent
48e213151a
commit
16888967c1
2 changed files with 31 additions and 17 deletions
|
@ -54,16 +54,25 @@ JOBTYPE_TO_CONST = {
|
||||||
'YtW': SHARP_YtW
|
'YtW': SHARP_YtW
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
|
||||||
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
|
||||||
int spin=0, comm=None, add=False):
|
int spin=0, comm=None, add=False):
|
||||||
cdef void *comm_ptr
|
cdef void *comm_ptr
|
||||||
cdef int flags = SHARP_DP | (SHARP_ADD if add else 0)
|
cdef int flags = SHARP_DP | (SHARP_ADD if add else 0)
|
||||||
cdef double *palm
|
|
||||||
cdef double *pmap
|
|
||||||
cdef int r
|
cdef int r
|
||||||
cdef sharp_jobtype jobtype_i
|
cdef sharp_jobtype jobtype_i
|
||||||
cdef double[::1] output_buf
|
cdef double[:, :, ::1] output_buf
|
||||||
|
cdef int ntrans = input.shape[0] * input.shape[1]
|
||||||
|
cdef int i, j
|
||||||
|
|
||||||
|
if spin == 0 and input.shape[1] != 1:
|
||||||
|
raise ValueError('For spin == 0, we need input.shape[1] == 1')
|
||||||
|
elif spin != 0 and input.shape[1] != 2:
|
||||||
|
raise ValueError('For spin != 0, we need input.shape[1] == 2')
|
||||||
|
|
||||||
|
|
||||||
|
cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
|
||||||
|
cdef double **alm_ptrs = <double**>&ptrbuf[0]
|
||||||
|
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
jobtype_i = JOBTYPE_TO_CONST[jobtype]
|
jobtype_i = JOBTYPE_TO_CONST[jobtype]
|
||||||
|
@ -71,23 +80,27 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
||||||
raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys())))
|
raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys())))
|
||||||
|
|
||||||
if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
|
if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
|
||||||
output = np.empty(ginfo.local_size(), dtype=np.float64)
|
output = np.empty((input.shape[0], input.shape[1], ginfo.local_size()), dtype=np.float64)
|
||||||
output_buf = output
|
output_buf = output
|
||||||
pmap = &output_buf[0]
|
for i in range(input.shape[0]):
|
||||||
palm = &input[0]
|
for j in range(input.shape[1]):
|
||||||
|
alm_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
|
||||||
|
map_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
|
||||||
else:
|
else:
|
||||||
output = np.empty(ainfo.local_size(), dtype=np.float64)
|
output = np.empty((input.shape[0], input.shape[1], ainfo.local_size()), dtype=np.float64)
|
||||||
output_buf = output
|
output_buf = output
|
||||||
pmap = &input[0]
|
for i in range(input.shape[0]):
|
||||||
palm = &output_buf[0]
|
for j in range(input.shape[1]):
|
||||||
|
alm_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
|
||||||
|
map_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
|
||||||
|
|
||||||
if comm is None:
|
if comm is None:
|
||||||
with nogil:
|
with nogil:
|
||||||
sharp_execute (
|
sharp_execute (
|
||||||
jobtype_i,
|
jobtype_i,
|
||||||
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
||||||
spin=spin, alm=&palm, map=&pmap,
|
spin=spin, alm=alm_ptrs, map=map_ptrs,
|
||||||
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
|
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
|
||||||
else:
|
else:
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
if not isinstance(comm, MPI.Comm):
|
if not isinstance(comm, MPI.Comm):
|
||||||
|
@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
|
||||||
r = sharp_execute_mpi_maybe (
|
r = sharp_execute_mpi_maybe (
|
||||||
comm_ptr, jobtype_i,
|
comm_ptr, jobtype_i,
|
||||||
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
|
||||||
spin=spin, alm=&palm, map=&pmap,
|
spin=spin, alm=alm_ptrs, map=map_ptrs,
|
||||||
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
|
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
|
||||||
if r == SHARP_ERROR_NO_MPI:
|
if r == SHARP_ERROR_NO_MPI:
|
||||||
raise Exception('MPI requested, but not available')
|
raise Exception('MPI requested, but not available')
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,9 @@ def test_basic():
|
||||||
alm[0] = 1
|
alm[0] = 1
|
||||||
|
|
||||||
|
|
||||||
|
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
|
||||||
map = libsharp.synthesis(grid, order, alm, comm=MPI.COMM_WORLD)
|
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
|
||||||
|
map = map[0, 0, :]
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
healpy.mollzoom(map)
|
healpy.mollzoom(map)
|
||||||
from matplotlib.pyplot import show
|
from matplotlib.pyplot import show
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue