do not support multiple simultaneous transforms any more
This commit is contained in:
parent
65f47d10cc
commit
c56747d36e
8 changed files with 167 additions and 333 deletions
|
@ -487,10 +487,10 @@ NOINLINE static void init_output (sharp_job *job)
|
|||
{
|
||||
if (job->flags&SHARP_ADD) return;
|
||||
if (job->type == SHARP_MAP2ALM)
|
||||
for (int i=0; i<job->ntrans*job->nalm; ++i)
|
||||
for (int i=0; i<job->nalm; ++i)
|
||||
clear_alm (job->ainfo,job->alm[i],job->flags);
|
||||
else
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
clear_map (job->ginfo,job->map[i],job->flags);
|
||||
}
|
||||
|
||||
|
@ -498,24 +498,24 @@ NOINLINE static void alloc_phase (sharp_job *job, int nm, int ntheta)
|
|||
{
|
||||
if (job->type==SHARP_MAP2ALM)
|
||||
{
|
||||
job->s_m=2*job->ntrans*job->nmaps;
|
||||
job->s_m=2*job->nmaps;
|
||||
if (((job->s_m*16*nm)&1023)==0) nm+=3; // hack to avoid critical strides
|
||||
job->s_th=job->s_m*nm;
|
||||
}
|
||||
else
|
||||
{
|
||||
job->s_th=2*job->ntrans*job->nmaps;
|
||||
job->s_th=2*job->nmaps;
|
||||
if (((job->s_th*16*ntheta)&1023)==0) ntheta+=3; // hack to avoid critical strides
|
||||
job->s_m=job->s_th*ntheta;
|
||||
}
|
||||
job->phase=RALLOC(dcmplx,2*job->ntrans*job->nmaps*nm*ntheta);
|
||||
job->phase=RALLOC(dcmplx,2*job->nmaps*nm*ntheta);
|
||||
}
|
||||
|
||||
static void dealloc_phase (sharp_job *job)
|
||||
{ DEALLOC(job->phase); }
|
||||
|
||||
static void alloc_almtmp (sharp_job *job, int lmax)
|
||||
{ job->almtmp=RALLOC(dcmplx,job->ntrans*job->nalm*(lmax+1)); }
|
||||
{ job->almtmp=RALLOC(dcmplx,job->nalm*(lmax+1)); }
|
||||
|
||||
static void dealloc_almtmp (sharp_job *job)
|
||||
{ DEALLOC(job->almtmp); }
|
||||
|
@ -526,13 +526,13 @@ NOINLINE static void alm2almtmp (sharp_job *job, int lmax, int mi)
|
|||
#define COPY_LOOP(real_t, source_t, expr_of_x) \
|
||||
{ \
|
||||
for (int l=m; l<lmin; ++l) \
|
||||
for (int i=0; i<job->ntrans*job->nalm; ++i) \
|
||||
job->almtmp[job->ntrans*job->nalm*l+i] = 0; \
|
||||
for (int i=0; i<job->nalm; ++i) \
|
||||
job->almtmp[job->nalm*l+i] = 0; \
|
||||
for (int l=lmin; l<=lmax; ++l) \
|
||||
for (int i=0; i<job->ntrans*job->nalm; ++i) \
|
||||
for (int i=0; i<job->nalm; ++i) \
|
||||
{ \
|
||||
source_t x = *(source_t *)(((real_t *)job->alm[i])+ofs+l*stride); \
|
||||
job->almtmp[job->ntrans*job->nalm*l+i] = expr_of_x; \
|
||||
job->almtmp[job->nalm*l+i] = expr_of_x; \
|
||||
} \
|
||||
}
|
||||
|
||||
|
@ -586,8 +586,8 @@ NOINLINE static void alm2almtmp (sharp_job *job, int lmax, int mi)
|
|||
}
|
||||
}
|
||||
else
|
||||
memset (job->almtmp+job->ntrans*job->nalm*job->ainfo->mval[mi], 0,
|
||||
job->ntrans*job->nalm*(lmax+1-job->ainfo->mval[mi])*sizeof(dcmplx));
|
||||
memset (job->almtmp+job->nalm*job->ainfo->mval[mi], 0,
|
||||
job->nalm*(lmax+1-job->ainfo->mval[mi])*sizeof(dcmplx));
|
||||
|
||||
#undef COPY_LOOP
|
||||
}
|
||||
|
@ -597,9 +597,9 @@ NOINLINE static void almtmp2alm (sharp_job *job, int lmax, int mi)
|
|||
|
||||
#define COPY_LOOP(real_t, target_t, expr_of_x) \
|
||||
for (int l=lmin; l<=lmax; ++l) \
|
||||
for (int i=0; i<job->ntrans*job->nalm; ++i) \
|
||||
for (int i=0; i<job->nalm; ++i) \
|
||||
{ \
|
||||
dcmplx x = job->almtmp[job->ntrans*job->nalm*l+i]; \
|
||||
dcmplx x = job->almtmp[job->nalm*l+i]; \
|
||||
*(target_t *)(((real_t *)job->alm[i])+ofs+l*stride) += expr_of_x; \
|
||||
}
|
||||
|
||||
|
@ -660,7 +660,7 @@ NOINLINE static void ringtmp2ring (sharp_job *job, sharp_ringinfo *ri,
|
|||
if (job->flags & SHARP_DP)
|
||||
{
|
||||
double **dmap = (double **)job->map;
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
{
|
||||
double *restrict p1=&dmap[i][ri->ofs];
|
||||
const double *restrict p2=&ringtmp[i*rstride+1];
|
||||
|
@ -680,7 +680,7 @@ NOINLINE static void ringtmp2ring (sharp_job *job, sharp_ringinfo *ri,
|
|||
else
|
||||
{
|
||||
float **fmap = (float **)job->map;
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
for (int m=0; m<ri->nph; ++m)
|
||||
fmap[i][ri->ofs+m*ri->stride] += (float)ringtmp[i*rstride+m+1];
|
||||
}
|
||||
|
@ -690,7 +690,7 @@ NOINLINE static void ring2ringtmp (sharp_job *job, sharp_ringinfo *ri,
|
|||
double *ringtmp, int rstride)
|
||||
{
|
||||
if (job->flags & SHARP_DP)
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
{
|
||||
double *restrict p1=&ringtmp[i*rstride+1],
|
||||
*restrict p2=&(((double *)(job->map[i]))[ri->ofs]);
|
||||
|
@ -701,7 +701,7 @@ NOINLINE static void ring2ringtmp (sharp_job *job, sharp_ringinfo *ri,
|
|||
p1[m] = p2[m*ri->stride];
|
||||
}
|
||||
else
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
for (int m=0; m<ri->nph; ++m)
|
||||
ringtmp[i*rstride+m+1] = ((float *)(job->map[i]))[ri->ofs+m*ri->stride];
|
||||
}
|
||||
|
@ -711,7 +711,7 @@ static void ring2phase_direct (sharp_job *job, sharp_ringinfo *ri, int mmax,
|
|||
{
|
||||
if (ri->nph<0)
|
||||
{
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
for (int m=0; m<=mmax; ++m)
|
||||
phase[2*i+job->s_m*m]=0.;
|
||||
}
|
||||
|
@ -721,7 +721,7 @@ static void ring2phase_direct (sharp_job *job, sharp_ringinfo *ri, int mmax,
|
|||
double wgt = (job->flags&SHARP_USE_WEIGHTS) ? (ri->nph*ri->weight) : 1.;
|
||||
if (job->flags&SHARP_REAL_HARMONICS)
|
||||
wgt *= sqrt_two;
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
for (int m=0; m<=mmax; ++m)
|
||||
phase[2*i+job->s_m*m]= (job->flags & SHARP_DP) ?
|
||||
((dcmplx *)(job->map[i]))[ri->ofs+m*ri->stride]*wgt :
|
||||
|
@ -738,7 +738,7 @@ static void phase2ring_direct (sharp_job *job, sharp_ringinfo *ri, int mmax,
|
|||
double wgt = (job->flags&SHARP_USE_WEIGHTS) ? (ri->nph*ri->weight) : 1.;
|
||||
if (job->flags&SHARP_REAL_HARMONICS)
|
||||
wgt *= sqrt_one_half;
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
for (int m=0; m<=mmax; ++m)
|
||||
if (job->flags & SHARP_DP)
|
||||
dmap[i][ri->ofs+m*ri->stride] += wgt*phase[2*i+job->s_m*m];
|
||||
|
@ -769,19 +769,19 @@ NOINLINE static void map2phase (sharp_job *job, int mmax, int llim, int ulim)
|
|||
ringhelper helper;
|
||||
ringhelper_init(&helper);
|
||||
int rstride=job->ginfo->nphmax+2;
|
||||
double *ringtmp=RALLOC(double,job->ntrans*job->nmaps*rstride);
|
||||
double *ringtmp=RALLOC(double,job->nmaps*rstride);
|
||||
#pragma omp for schedule(dynamic,1)
|
||||
for (int ith=llim; ith<ulim; ++ith)
|
||||
{
|
||||
int dim2 = job->s_th*(ith-llim);
|
||||
ring2ringtmp(job,&(job->ginfo->pair[ith].r1),ringtmp,rstride);
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
ringhelper_ring2phase (&helper,&(job->ginfo->pair[ith].r1),
|
||||
&ringtmp[i*rstride],mmax,&job->phase[dim2+2*i],pstride,job->flags);
|
||||
if (job->ginfo->pair[ith].r2.nph>0)
|
||||
{
|
||||
ring2ringtmp(job,&(job->ginfo->pair[ith].r2),ringtmp,rstride);
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
ringhelper_ring2phase (&helper,&(job->ginfo->pair[ith].r2),
|
||||
&ringtmp[i*rstride],mmax,&job->phase[dim2+2*i+1],pstride,job->flags);
|
||||
}
|
||||
|
@ -814,18 +814,18 @@ NOINLINE static void phase2map (sharp_job *job, int mmax, int llim, int ulim)
|
|||
ringhelper helper;
|
||||
ringhelper_init(&helper);
|
||||
int rstride=job->ginfo->nphmax+2;
|
||||
double *ringtmp=RALLOC(double,job->ntrans*job->nmaps*rstride);
|
||||
double *ringtmp=RALLOC(double,job->nmaps*rstride);
|
||||
#pragma omp for schedule(dynamic,1)
|
||||
for (int ith=llim; ith<ulim; ++ith)
|
||||
{
|
||||
int dim2 = job->s_th*(ith-llim);
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
ringhelper_phase2ring (&helper,&(job->ginfo->pair[ith].r1),
|
||||
&ringtmp[i*rstride],mmax,&job->phase[dim2+2*i],pstride,job->flags);
|
||||
ringtmp2ring(job,&(job->ginfo->pair[ith].r1),ringtmp,rstride);
|
||||
if (job->ginfo->pair[ith].r2.nph>0)
|
||||
{
|
||||
for (int i=0; i<job->ntrans*job->nmaps; ++i)
|
||||
for (int i=0; i<job->nmaps; ++i)
|
||||
ringhelper_phase2ring (&helper,&(job->ginfo->pair[ith].r2),
|
||||
&ringtmp[i*rstride],mmax,&job->phase[dim2+2*i+1],pstride,job->flags);
|
||||
ringtmp2ring(job,&(job->ginfo->pair[ith].r2),ringtmp,rstride);
|
||||
|
@ -918,10 +918,8 @@ NOINLINE static void sharp_execute_job (sharp_job *job)
|
|||
|
||||
static void sharp_build_job_common (sharp_job *job, sharp_jobtype type,
|
||||
int spin, void *alm, void *map, const sharp_geom_info *geom_info,
|
||||
const sharp_alm_info *alm_info, int ntrans, int flags)
|
||||
const sharp_alm_info *alm_info, int flags)
|
||||
{
|
||||
UTIL_ASSERT((ntrans>0)&&(ntrans<=SHARP_MAXTRANS),
|
||||
"bad number of simultaneous transforms");
|
||||
if (type==SHARP_ALM2MAP_DERIV1) spin=1;
|
||||
if (type==SHARP_MAP2ALM) flags|=SHARP_USE_WEIGHTS;
|
||||
if (type==SHARP_Yt) type=SHARP_MAP2ALM;
|
||||
|
@ -937,23 +935,22 @@ static void sharp_build_job_common (sharp_job *job, sharp_jobtype type,
|
|||
job->ainfo = alm_info;
|
||||
job->flags = flags;
|
||||
if ((job->flags&SHARP_NVMAX)==0)
|
||||
job->flags|=sharp_nv_oracle (type, spin, ntrans);
|
||||
job->flags|=sharp_nv_oracle (type, spin);
|
||||
if (alm_info->flags&SHARP_REAL_HARMONICS)
|
||||
job->flags|=SHARP_REAL_HARMONICS;
|
||||
job->time = 0.;
|
||||
job->opcnt = 0;
|
||||
job->ntrans = ntrans;
|
||||
job->alm=alm;
|
||||
job->map=map;
|
||||
}
|
||||
|
||||
void sharp_execute (sharp_jobtype type, int spin, void *alm, void *map,
|
||||
const sharp_geom_info *geom_info, const sharp_alm_info *alm_info, int ntrans,
|
||||
const sharp_geom_info *geom_info, const sharp_alm_info *alm_info,
|
||||
int flags, double *time, unsigned long long *opcnt)
|
||||
{
|
||||
sharp_job job;
|
||||
sharp_build_job_common (&job, type, spin, alm, map, geom_info, alm_info,
|
||||
ntrans, flags);
|
||||
flags);
|
||||
|
||||
sharp_execute_job (&job);
|
||||
if (time!=NULL) *time = job.time;
|
||||
|
@ -968,7 +965,7 @@ void sharp_set_nchunks_max(int new_nchunks_max)
|
|||
int sharp_get_nv_max (void)
|
||||
{ return 6; }
|
||||
|
||||
static int sharp_oracle (sharp_jobtype type, int spin, int ntrans)
|
||||
static int sharp_oracle (sharp_jobtype type, int spin)
|
||||
{
|
||||
int lmax=511;
|
||||
int mmax=(lmax+1)/2;
|
||||
|
@ -982,7 +979,7 @@ static int sharp_oracle (sharp_jobtype type, int spin, int ntrans)
|
|||
sharp_make_gauss_geom_info (nrings, ppring, 0., 1, ppring, &tinfo);
|
||||
|
||||
ptrdiff_t nalms = ((mmax+1)*(mmax+2))/2 + (mmax+1)*(lmax-mmax);
|
||||
int ncomp = ntrans*((spin==0) ? 1 : 2);
|
||||
int ncomp = (spin==0) ? 1 : 2;
|
||||
|
||||
double **map;
|
||||
ALLOC2D(map,double,ncomp,npix);
|
||||
|
@ -1005,7 +1002,7 @@ static int sharp_oracle (sharp_jobtype type, int spin, int ntrans)
|
|||
int ntries=0;
|
||||
do
|
||||
{
|
||||
sharp_execute(type,spin,&alm[0],&map[0],tinfo,alms,ntrans,
|
||||
sharp_execute(type,spin,&alm[0],&map[0],tinfo,alms,
|
||||
nv|SHARP_DP|SHARP_NO_OPENMP,&jtime,NULL);
|
||||
|
||||
if (jtime<time) { time=jtime; nvbest=nv; }
|
||||
|
@ -1023,26 +1020,18 @@ static int sharp_oracle (sharp_jobtype type, int spin, int ntrans)
|
|||
return nvbest;
|
||||
}
|
||||
|
||||
int sharp_nv_oracle (sharp_jobtype type, int spin, int ntrans)
|
||||
int sharp_nv_oracle (sharp_jobtype type, int spin)
|
||||
{
|
||||
static const int maxtr = 6;
|
||||
static int nv_opt[6][2][5] = {
|
||||
{{0,0,0,0,0},{0,0,0,0,0}},
|
||||
{{0,0,0,0,0},{0,0,0,0,0}},
|
||||
{{0,0,0,0,0},{0,0,0,0,0}},
|
||||
{{0,0,0,0,0},{0,0,0,0,0}},
|
||||
{{0,0,0,0,0},{0,0,0,0,0}},
|
||||
{{0,0,0,0,0},{0,0,0,0,0}} };
|
||||
static int nv_opt[2][5] = {{0,0,0,0,0},{0,0,0,0,0}};
|
||||
|
||||
if (type==SHARP_ALM2MAP_DERIV1) spin=1;
|
||||
UTIL_ASSERT(type<5,"bad type");
|
||||
UTIL_ASSERT((ntrans>0),"bad number of simultaneous transforms");
|
||||
UTIL_ASSERT(spin>=0, "bad spin");
|
||||
ntrans=IMIN(ntrans,maxtr);
|
||||
|
||||
if (nv_opt[ntrans-1][spin!=0][type]==0)
|
||||
nv_opt[ntrans-1][spin!=0][type]=sharp_oracle(type,spin,ntrans);
|
||||
return nv_opt[ntrans-1][spin!=0][type];
|
||||
if (nv_opt[spin!=0][type]==0)
|
||||
nv_opt[spin!=0][type]=sharp_oracle(type,spin);
|
||||
return nv_opt[spin!=0][type];
|
||||
}
|
||||
|
||||
#ifdef USE_MPI
|
||||
|
@ -1050,11 +1039,11 @@ int sharp_nv_oracle (sharp_jobtype type, int spin, int ntrans)
|
|||
|
||||
int sharp_execute_mpi_maybe (void *pcomm, sharp_jobtype type, int spin,
|
||||
void *alm, void *map, const sharp_geom_info *geom_info,
|
||||
const sharp_alm_info *alm_info, int ntrans, int flags, double *time,
|
||||
const sharp_alm_info *alm_info, int flags, double *time,
|
||||
unsigned long long *opcnt)
|
||||
{
|
||||
MPI_Comm comm = *(MPI_Comm*)pcomm;
|
||||
sharp_execute_mpi((MPI_Comm)comm, type, spin, alm, map, geom_info, alm_info, ntrans,
|
||||
sharp_execute_mpi((MPI_Comm)comm, type, spin, alm, map, geom_info, alm_info,
|
||||
flags, time, opcnt);
|
||||
return 0;
|
||||
}
|
||||
|
@ -1063,12 +1052,12 @@ int sharp_execute_mpi_maybe (void *pcomm, sharp_jobtype type, int spin,
|
|||
|
||||
int sharp_execute_mpi_maybe (void *pcomm, sharp_jobtype type, int spin,
|
||||
void *alm, void *map, const sharp_geom_info *geom_info,
|
||||
const sharp_alm_info *alm_info, int ntrans, int flags, double *time,
|
||||
const sharp_alm_info *alm_info, int flags, double *time,
|
||||
unsigned long long *opcnt)
|
||||
{
|
||||
/* Suppress unused warning: */
|
||||
(void)pcomm; (void)type; (void)spin; (void)alm; (void)map; (void)geom_info;
|
||||
(void)alm_info; (void)ntrans; (void)flags; (void)time; (void)opcnt;
|
||||
(void)alm_info; (void)flags; (void)time; (void)opcnt;
|
||||
return SHARP_ERROR_NO_MPI;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue