rework interface, put mor stuff into flags

This commit is contained in:
Martin Reinecke 2012-11-09 12:53:14 +01:00
parent 0a1a9e5716
commit 9f46084386
12 changed files with 113 additions and 160 deletions

View file

@ -162,7 +162,7 @@ void sharp_destroy_alm_info (sharp_alm_info *info)
void sharp_make_geom_info (int nrings, const int *nph, const ptrdiff_t *ofs,
const int *stride, const double *phi0, const double *theta,
const double *weight, sharp_geom_info **geom_info)
const double *wgt_a2m, const double *wgt_m2a, sharp_geom_info **geom_info)
{
sharp_geom_info *info = RALLOC(sharp_geom_info,1);
sharp_ringinfo *infos = RALLOC(sharp_ringinfo,nrings);
@ -177,7 +177,8 @@ void sharp_make_geom_info (int nrings, const int *nph, const ptrdiff_t *ofs,
infos[m].theta = theta[m];
infos[m].cth = cos(theta[m]);
infos[m].sth = sin(theta[m]);
infos[m].weight = (weight != NULL) ? weight[m] : 1.;
infos[m].w_a2m = (wgt_a2m != NULL) ? wgt_a2m[m] : 1.;
infos[m].w_m2a = (wgt_m2a != NULL) ? wgt_m2a[m] : 1.;
infos[m].phi0 = phi0[m];
infos[m].ofs = ofs[m];
infos[m].stride = stride[m];
@ -234,7 +235,7 @@ static int sharp_get_mmax (int *mval, int nm)
static void ringhelper_phase2ring (ringhelper *self,
const sharp_ringinfo *info, void *data, int mmax, const dcmplx *phase,
int pstride, sharp_fde fde, int flags)
int pstride, int flags)
{
int nph = info->nph;
int stride = info->stride;
@ -274,30 +275,18 @@ static void ringhelper_phase2ring (ringhelper *self,
}
#endif
real_plan_backward_c (self->plan, (double *)(self->work));
if (flags & SHARP_ALM2MAP_USE_WEIGHTS)
{
if (fde==DOUBLE)
for (int m=0; m<nph; ++m)
((double *)data)[m*stride+info->ofs]+=creal(self->work[m])*info->weight;
else
for (int m=0; m<nph; ++m)
((float *)data)[m*stride+info->ofs] +=
(float)(creal(self->work[m])*info->weight);
}
if (flags&SHARP_DP)
for (int m=0; m<nph; ++m)
((double *)data)[m*stride+info->ofs]+=creal(self->work[m])*info->w_a2m;
else
{
if (fde==DOUBLE)
for (int m=0; m<nph; ++m)
((double *)data)[m*stride+info->ofs] += creal(self->work[m]);
else
for (int m=0; m<nph; ++m)
((float *)data)[m*stride+info->ofs] += (float)creal(self->work[m]);
}
for (int m=0; m<nph; ++m)
((float *)data)[m*stride+info->ofs] +=
(float)(creal(self->work[m])*info->w_a2m);
}
static void ringhelper_ring2phase (ringhelper *self,
const sharp_ringinfo *info, const void *data, int mmax, dcmplx *phase,
int pstride, sharp_fde fde, int flags)
int pstride, int flags)
{
int nph = info->nph;
#if 1
@ -307,24 +296,12 @@ static void ringhelper_ring2phase (ringhelper *self,
#endif
ringhelper_update (self, nph, mmax, -info->phi0);
if (flags & SHARP_MAP2ALM_IGNORE_WEIGHTS)
{
if (fde==DOUBLE)
for (int m=0; m<nph; ++m)
self->work[m] = ((double *)data)[info->ofs+m*info->stride];
else
for (int m=0; m<nph; ++m)
self->work[m] = ((float *)data)[info->ofs+m*info->stride];
}
if (flags&SHARP_DP)
for (int m=0; m<nph; ++m)
self->work[m] = ((double *)data)[info->ofs+m*info->stride]*info->w_m2a;
else
{
if (fde==DOUBLE)
for (int m=0; m<nph; ++m)
self->work[m] = ((double *)data)[info->ofs+m*info->stride]*info->weight;
else
for (int m=0; m<nph; ++m)
self->work[m] = ((float *)data)[info->ofs+m*info->stride]*info->weight;
}
for (int m=0; m<nph; ++m)
self->work[m] = ((float *)data)[info->ofs+m*info->stride]*info->w_m2a;
real_plan_forward_c (self->plan, (double *)self->work);
@ -341,28 +318,28 @@ static void ringhelper_ring2phase (ringhelper *self,
static void ringhelper_pair2phase (ringhelper *self, int mmax,
const sharp_ringpair *pair, const void *data, dcmplx *phase1, dcmplx *phase2,
int pstride, sharp_fde fde, int flags)
int pstride, int flags)
{
ringhelper_ring2phase (self,&(pair->r1),data,mmax,phase1,pstride,fde,flags);
ringhelper_ring2phase (self,&(pair->r1),data,mmax,phase1,pstride,flags);
if (pair->r2.nph>0)
ringhelper_ring2phase (self,&(pair->r2),data,mmax,phase2,pstride,fde,flags);
ringhelper_ring2phase (self,&(pair->r2),data,mmax,phase2,pstride,flags);
}
static void ringhelper_phase2pair (ringhelper *self, int mmax,
const dcmplx *phase1, const dcmplx *phase2, int pstride,
const sharp_ringpair *pair, void *data, sharp_fde fde, int flags)
const sharp_ringpair *pair, void *data, int flags)
{
ringhelper_phase2ring (self,&(pair->r1),data,mmax,phase1,pstride,fde,flags);
ringhelper_phase2ring (self,&(pair->r1),data,mmax,phase1,pstride,flags);
if (pair->r2.nph>0)
ringhelper_phase2ring (self,&(pair->r2),data,mmax,phase2,pstride,fde,flags);
ringhelper_phase2ring (self,&(pair->r2),data,mmax,phase2,pstride,flags);
}
static void fill_map (const sharp_geom_info *ginfo, void *map, double value,
sharp_fde fde)
int flags)
{
for (int j=0;j<ginfo->npairs;++j)
{
if (fde==DOUBLE)
if (flags&SHARP_DP)
{
for (int i=0;i<ginfo->pair[j].r1.nph;++i)
((double *)map)[ginfo->pair[j].r1.ofs+i*ginfo->pair[j].r1.stride]=value;
@ -382,9 +359,9 @@ static void fill_map (const sharp_geom_info *ginfo, void *map, double value,
}
static void fill_alm (const sharp_alm_info *ainfo, void *alm, dcmplx value,
sharp_fde fde)
int flags)
{
if (fde==DOUBLE)
if (flags&SHARP_DP)
for (int mi=0;mi<ainfo->nm;++mi)
for (int l=ainfo->mval[mi];l<=ainfo->lmax;++l)
((dcmplx *)alm)[sharp_alm_index(ainfo,l,mi)] = value;
@ -396,13 +373,13 @@ static void fill_alm (const sharp_alm_info *ainfo, void *alm, dcmplx value,
static void init_output (sharp_job *job)
{
if (job->add_output) return;
if (job->flags&SHARP_ADD) return;
if (job->type == SHARP_MAP2ALM)
for (int i=0; i<job->ntrans*job->nalm; ++i)
fill_alm (job->ainfo,job->alm[i],0.,job->fde);
fill_alm (job->ainfo,job->alm[i],0.,job->flags);
else
for (int i=0; i<job->ntrans*job->nmaps; ++i)
fill_map (job->ginfo,job->map[i],0.,job->fde);
fill_map (job->ginfo,job->map[i],0.,job->flags);
}
static void alloc_phase (sharp_job *job, int nm, int ntheta)
@ -426,8 +403,7 @@ static void map2phase (sharp_job *job, int mmax, int llim, int ulim)
int dim2 = pstride*(ith-llim)*(mmax+1);
for (int i=0; i<job->ntrans*job->nmaps; ++i)
ringhelper_pair2phase(&helper,mmax,&job->ginfo->pair[ith], job->map[i],
&job->phase[dim2+2*i], &job->phase[dim2+2*i+1], pstride, job->fde,
job->flags);
&job->phase[dim2+2*i], &job->phase[dim2+2*i+1], pstride, job->flags);
}
ringhelper_destroy(&helper);
} /* end of parallel region */
@ -447,7 +423,7 @@ static void alm2almtmp (sharp_job *job, int lmax, int mi)
int stride=job->ainfo->stride;
if (job->spin==0)
{
if (job->fde==DOUBLE)
if (job->flags&SHARP_DP)
for (int l=job->ainfo->mval[mi]; l<=lmax; ++l)
for (int i=0; i<job->ntrans*job->nalm; ++i)
job->almtmp[job->ntrans*job->nalm*l+i]
@ -460,7 +436,7 @@ static void alm2almtmp (sharp_job *job, int lmax, int mi)
}
else
{
if (job->fde==DOUBLE)
if (job->flags&SHARP_DP)
for (int l=job->ainfo->mval[mi]; l<=lmax; ++l)
for (int i=0; i<job->ntrans*job->nalm; ++i)
job->almtmp[job->ntrans*job->nalm*l+i]
@ -484,7 +460,7 @@ static void almtmp2alm (sharp_job *job, int lmax, int mi)
int stride=job->ainfo->stride;
if (job->spin==0)
{
if (job->fde==DOUBLE)
if (job->flags&SHARP_DP)
for (int l=job->ainfo->mval[mi]; l<=lmax; ++l)
for (int i=0;i<job->ntrans*job->nalm;++i)
((dcmplx *)job->alm[i])[ofs+l*stride] +=
@ -497,7 +473,7 @@ static void almtmp2alm (sharp_job *job, int lmax, int mi)
}
else
{
if (job->fde==DOUBLE)
if (job->flags&SHARP_DP)
for (int l=job->ainfo->mval[mi]; l<=lmax; ++l)
for (int i=0;i<job->ntrans*job->nalm;++i)
((dcmplx *)job->alm[i])[ofs+l*stride] +=
@ -525,7 +501,7 @@ static void phase2map (sharp_job *job, int mmax, int llim, int ulim)
for (int i=0; i<job->ntrans*job->nmaps; ++i)
ringhelper_phase2pair(&helper,mmax,&job->phase[dim2+2*i],
&job->phase[dim2+2*i+1],pstride,&job->ginfo->pair[ith],job->map[i],
job->fde, job->flags);
job->flags);
}
ringhelper_destroy(&helper);
} /* end of parallel region */
@ -546,7 +522,8 @@ static void sharp_execute_job (sharp_job *job)
init_output (job);
int nchunks, chunksize;
get_chunk_info(job->ginfo->npairs,job->nv*VLEN,&nchunks,&chunksize);
get_chunk_info(job->ginfo->npairs,(job->flags&SHARP_NVMAX)*VLEN,&nchunks,
&chunksize);
alloc_phase (job,mmax+1,chunksize);
/* chunk loop */
@ -615,9 +592,8 @@ static void sharp_execute_job (sharp_job *job)
}
static void sharp_build_job_common (sharp_job *job, sharp_jobtype type,
int spin, int add_output, void *alm, void *map,
const sharp_geom_info *geom_info, const sharp_alm_info *alm_info, int ntrans,
int flags, int nv)
int spin, void *alm, void *map, const sharp_geom_info *geom_info,
const sharp_alm_info *alm_info, int ntrans, int flags)
{
UTIL_ASSERT((ntrans>0)&&(ntrans<=SHARP_MAXTRANS),
"bad number of simultaneous transforms");
@ -628,28 +604,27 @@ static void sharp_build_job_common (sharp_job *job, sharp_jobtype type,
job->type = type;
job->spin = spin;
job->norm_l = NULL;
job->add_output = add_output;
job->nmaps = (type==SHARP_ALM2MAP_DERIV1) ? 2 : ((spin>0) ? 2 : 1);
job->nalm = (type==SHARP_ALM2MAP_DERIV1) ? 1 : ((spin>0) ? 2 : 1);
job->ginfo = geom_info;
job->ainfo = alm_info;
job->nv = (nv==0) ? sharp_nv_oracle (type, spin, ntrans) : nv;
job->flags = flags;
if ((job->flags&SHARP_NVMAX)==0)
job->flags|=sharp_nv_oracle (type, spin, ntrans);
job->time = 0.;
job->opcnt = 0;
job->ntrans = ntrans;
job->alm=alm;
job->map=map;
job->flags = flags;
job->fde=(flags & SHARP_DP) ? DOUBLE : FLOAT;
}
void sharp_execute (sharp_jobtype type, int spin, int add_output, void *alm,
void *map, const sharp_geom_info *geom_info, const sharp_alm_info *alm_info,
int ntrans, int flags, int nv, double *time, unsigned long long *opcnt)
void sharp_execute (sharp_jobtype type, int spin, void *alm, void *map,
const sharp_geom_info *geom_info, const sharp_alm_info *alm_info, int ntrans,
int flags, double *time, unsigned long long *opcnt)
{
sharp_job job;
sharp_build_job_common (&job, type, spin, add_output, alm, map, geom_info,
alm_info, ntrans, flags, nv);
sharp_build_job_common (&job, type, spin, alm, map, geom_info, alm_info,
ntrans, flags);
sharp_execute_job (&job);
if (time!=NULL) *time = job.time;
@ -701,8 +676,8 @@ static int sharp_oracle (sharp_jobtype type, int spin, int ntrans)
int ntries=0;
do
{
sharp_execute(type,spin,0,&alm[0],&map[0],tinfo,alms,ntrans,1,nv,&jtime,
NULL);
sharp_execute(type,spin,&alm[0],&map[0],tinfo,alms,ntrans,nv|SHARP_DP,
&jtime,NULL);
if (jtime<time) { time=jtime; nvbest=nv; }
time_acc+=jtime;