diff --git a/libsharp/sharp_core_inc.c b/libsharp/sharp_core_inc.c index 22e8cad..214e526 100644 --- a/libsharp/sharp_core_inc.c +++ b/libsharp/sharp_core_inc.c @@ -220,7 +220,7 @@ static inline void Y(rec_step) (Tb * restrict rxp, Tb * restrict rxm, } } -static void Y(iter_to_ieee_spin) (const Tb cth, int *l_, +static void Y(iter_to_ieee_spin) (const Tb cth, const Tb sth, int *l_, Tb * rec1p_, Tb * rec1m_, Tb * rec2p_, Tb * rec2m_, Tb * scalep_, Tb * scalem_, const sharp_Ylmgen_C * restrict gen) { @@ -232,6 +232,11 @@ static void Y(iter_to_ieee_spin) (const Tb cth, int *l_, cth2.v[i]=vmax(cth2.v[i],vload(1e-15)); sth2.v[i]=vsqrt(vmul(vsub(vone,cth.v[i]),vload(0.5))); sth2.v[i]=vmax(sth2.v[i],vload(1e-15)); + Tv mask=vlt(sth.v[i],vzero); + Tv cfct=vblend(vand(mask,vlt(cth.v[i],vzero)),vload(-1.),vone); + cth2.v[i]=vmul(cth2.v[i],cfct); + Tv sfct=vblend(vand(mask,vgt(cth.v[i],vzero)),vload(-1.),vone); + sth2.v[i]=vmul(sth2.v[i],sfct); } Tb ccp, ccps, ssp, ssps, csp, csps, scp, scps; diff --git a/libsharp/sharp_core_inc2.c b/libsharp/sharp_core_inc2.c index 4e81def..5c9b4ab 100644 --- a/libsharp/sharp_core_inc2.c +++ b/libsharp/sharp_core_inc2.c @@ -429,12 +429,14 @@ static void Z(map2alm_spin_kernel) (Tb cth, const Y(Tbqu) * restrict p1, Z(saddstep2)(p1, p2, &rec2p, &rec2m, &alm[2*njobs*l] NJ2); } -static void Z(calc_alm2map_spin) (const Tb cth, const sharp_Ylmgen_C *gen, - sharp_job *job, Y(Tbqu) * restrict p1, Y(Tbqu) * restrict p2 NJ1) +static void Z(calc_alm2map_spin) (const Tb cth, const Tb sth, + const sharp_Ylmgen_C *gen, sharp_job *job, Y(Tbqu) * restrict p1, + Y(Tbqu) * restrict p2 NJ1) { int l, lmax=gen->lmax; Tb rec1p, rec1m, rec2p, rec2m, scalem, scalep; - Y(iter_to_ieee_spin) (cth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); + Y(iter_to_ieee_spin) + (cth,sth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); job->opcnt += (l-gen->m) * 10*VLEN*nvec; if (l>lmax) return; job->opcnt += (lmax+1-l) * (12+16*njobs)*VLEN*nvec; @@ -473,12 +475,14 @@ static void Z(calc_alm2map_spin) (const Tb cth, const sharp_Ylmgen_C *gen, lmax NJ2); } -static void Z(calc_map2alm_spin) (Tb cth, const sharp_Ylmgen_C * restrict gen, - sharp_job *job, const Y(Tbqu) * restrict p1, const Y(Tbqu) * restrict p2 NJ1) +static void Z(calc_map2alm_spin) (Tb cth, Tb sth, + const sharp_Ylmgen_C * restrict gen, sharp_job *job, + const Y(Tbqu) * restrict p1, const Y(Tbqu) * restrict p2 NJ1) { int l, lmax=gen->lmax; Tb rec1p, rec1m, rec2p, rec2m, scalem, scalep; - Y(iter_to_ieee_spin) (cth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); + Y(iter_to_ieee_spin) + (cth,sth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); job->opcnt += (l-gen->m) * 10*VLEN*nvec; if (l>lmax) return; job->opcnt += (lmax+1-l) * (12+16*njobs)*VLEN*nvec; @@ -568,12 +572,14 @@ static void Z(alm2map_deriv1_kernel) (Tb cth, Y(Tbqu) * restrict p1, Z(saddstep_d)(p1, p2, rec2p, rec2m, &alm[njobs*l] NJ2); } -static void Z(calc_alm2map_deriv1) (const Tb cth, const sharp_Ylmgen_C *gen, - sharp_job *job, Y(Tbqu) * restrict p1, Y(Tbqu) * restrict p2 NJ1) +static void Z(calc_alm2map_deriv1) (const Tb cth, const Tb sth, + const sharp_Ylmgen_C *gen, sharp_job *job, Y(Tbqu) * restrict p1, + Y(Tbqu) * restrict p2 NJ1) { int l, lmax=gen->lmax; Tb rec1p, rec1m, rec2p, rec2m, scalem, scalep; - Y(iter_to_ieee_spin) (cth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); + Y(iter_to_ieee_spin) + (cth,sth,&l,&rec1p,&rec1m,&rec2p,&rec2m,&scalep,&scalem,gen); job->opcnt += (l-gen->m) * 10*VLEN*nvec; if (l>lmax) return; job->opcnt += (lmax+1-l) * (12+8*njobs)*VLEN*nvec; @@ -669,7 +675,7 @@ static void Z(inner_loop) (sharp_job *job, const int *ispair, for (int ith=0; ith=ulim-llim) itot=ulim-llim-1; if (mlim[itot]>=m) skip=0; - cth.s[i]=cth_[itot]; + cth.s[i]=cth_[itot]; sth.s[i]=sth_[itot]; } if (!skip) (job->type==SHARP_ALM2MAP) ? Z(calc_alm2map_spin ) - (cth.b,gen,job,&p1[0].b,&p2[0].b NJ2) : + (cth.b,sth.b,gen,job,&p1[0].b,&p2[0].b NJ2) : Z(calc_alm2map_deriv1) - (cth.b,gen,job,&p1[0].b,&p2[0].b NJ2); + (cth.b,sth.b,gen,job,&p1[0].b,&p2[0].b NJ2); for (int i=0; i=ulim-llim) itot=ulim-llim-1; if (mlim[itot]>=m) skip=0; - cth.s[i]=cth_[itot]; + cth.s[i]=cth_[itot]; sth.s[i]=sth_[itot]; if (i+ith