diff --git a/libsharp/sharp_core_inc.c b/libsharp/sharp_core_inc.c
index 214e526..747658c 100644
--- a/libsharp/sharp_core_inc.c
+++ b/libsharp/sharp_core_inc.c
@@ -77,19 +77,19 @@ static void Y(Tbnormalize) (Tb * restrict val, Tb * restrict scale,
   const Tv vfmin=vload(sharp_fsmall*maxval), vfmax=vload(maxval);
   for (int i=0;i<nvec;++i)
     {
-    Tv mask = vgt(vabs(val->v[i]),vfmax);
+    Tm mask = vgt(vabs(val->v[i]),vfmax);
     while (vanyTrue(mask))
       {
-      vmuleq(val->v[i],vblend(mask,vfsmall,vone));
-      vaddeq(scale->v[i],vblend(mask,vone,vzero));
+      vmuleq_mask(mask,val->v[i],vfsmall);
+      vaddeq_mask(mask,scale->v[i],vone);
       mask = vgt(vabs(val->v[i]),vfmax);
       }
-    mask = vand(vlt(vabs(val->v[i]),vfmin),vne(val->v[i],vzero));
+    mask = vand_mask(vlt(vabs(val->v[i]),vfmin),vne(val->v[i],vzero));
     while (vanyTrue(mask))
       {
-      vmuleq(val->v[i],vblend(mask,vfbig,vone));
-      vsubeq(scale->v[i],vblend(mask,vone,vzero));
-      mask = vand(vlt(vabs(val->v[i]),vfmin),vne(val->v[i],vzero));
+      vmuleq_mask(mask,val->v[i],vfbig);
+      vsubeq_mask(mask,scale->v[i],vone);
+      mask = vand_mask(vlt(vabs(val->v[i]),vfmin),vne(val->v[i],vzero));
       }
     }
   }
@@ -131,13 +131,13 @@ static inline int Y(rescale) (Tb * restrict lam1, Tb * restrict lam2,
   int did_scale=0;
   for (int i=0;i<nvec;++i)
     {
-    Tv mask = vgt(vabs(lam2->v[i]),vload(sharp_ftol));
+    Tm mask = vgt(vabs(lam2->v[i]),vload(sharp_ftol));
     if (vanyTrue(mask))
       {
       did_scale=1;
-      Tv fact = vblend(mask,vload(sharp_fsmall),vone);
-      vmuleq(lam1->v[i],fact); vmuleq(lam2->v[i],fact);
-      vaddeq(scale->v[i],vblend(mask,vone,vzero));
+      vmuleq_mask(mask,lam1->v[i],vload(sharp_fsmall));
+      vmuleq_mask(mask,lam2->v[i],vload(sharp_fsmall));
+      vaddeq_mask(mask,scale->v[i],vone);
       }
     }
   return did_scale;
@@ -146,25 +146,25 @@ static inline int Y(rescale) (Tb * restrict lam1, Tb * restrict lam2,
 static inline int Y(TballLt)(Tb a,double b)
   {
   Tv vb=vload(b);
-  Tv res=vlt(a.v[0],vb);
+  Tm res=vlt(a.v[0],vb);
   for (int i=1; i<nvec; ++i)
-    res=vand(res,vlt(a.v[i],vb));
+    res=vand_mask(res,vlt(a.v[i],vb));
   return vallTrue(res);
   }
 static inline int Y(TballGt)(Tb a,double b)
   {
   Tv vb=vload(b);
-  Tv res=vgt(a.v[0],vb);
+  Tm res=vgt(a.v[0],vb);
   for (int i=1; i<nvec; ++i)
-    res=vand(res,vgt(a.v[i],vb));
+    res=vand_mask(res,vgt(a.v[i],vb));
   return vallTrue(res);
   }
 static inline int Y(TballGe)(Tb a,double b)
   {
   Tv vb=vload(b);
-  Tv res=vge(a.v[0],vb);
+  Tm res=vge(a.v[0],vb);
   for (int i=1; i<nvec; ++i)
-    res=vand(res,vge(a.v[i],vb));
+    res=vand_mask(res,vge(a.v[i],vb));
   return vallTrue(res);
   }
 
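The sharp_core_inc.c hunks above replace the old "blend a neutral factor, then
apply unconditionally" idiom with dedicated masked-update macros, so that
back-ends with real mask registers can implement them directly. The standalone
sketch below is not part of the patch; it assumes the scalar (VLEN==1)
semantics, where Tv is double and Tm is int, and shows that the two idioms
compute the same result:

    /* sketch: old vblend idiom vs. new *_mask idiom (scalar semantics) */
    #include <stdio.h>

    typedef double Tv;
    typedef int Tm;

    int main(void)
      {
      Tv val = 3.0, scale = 0.0;
      const Tv vfsmall = 0.5, vone = 1.0, vzero = 0.0;
      Tm mask = (val > 2.0);  /* comparisons now yield the mask type */

      /* old: factor is vfsmall where the mask is set, 1 elsewhere */
      Tv val_old   = val*(mask ? vfsmall : vone);  /* vmuleq(val,vblend(mask,vfsmall,vone)) */
      Tv scale_old = scale+(mask ? vone : vzero);  /* vaddeq(scale,vblend(mask,vone,vzero)) */

      /* new: touch only the lanes selected by the mask */
      if (mask) val *= vfsmall;                    /* vmuleq_mask(mask,val,vfsmall) */
      if (mask) scale += vone;                     /* vaddeq_mask(mask,scale,vone)  */

      printf("old: %g %g   new: %g %g\n", val_old, scale_old, val, scale);
      return 0;
      }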
diff --git a/libsharp/sharp_vecsupport.h b/libsharp/sharp_vecsupport.h
--- a/libsharp/sharp_vecsupport.h
+++ b/libsharp/sharp_vecsupport.h
@@ -42,13 +42,17 @@
 #if (VLEN==1)
 
 typedef double Tv;
+typedef int Tm;
 
 #define vadd(a,b) ((a)+(b))
 #define vaddeq(a,b) ((a)+=(b))
+#define vaddeq_mask(mask,a,b) if (mask) (a)+=(b);
 #define vsub(a,b) ((a)-(b))
 #define vsubeq(a,b) ((a)-=(b))
+#define vsubeq_mask(mask,a,b) if (mask) (a)-=(b);
 #define vmul(a,b) ((a)*(b))
 #define vmuleq(a,b) ((a)*=(b))
+#define vmuleq_mask(mask,a,b) if (mask) (a)*=(b);
 #define vfmaeq(a,b,c) ((a)+=(b)*(c))
 #define vfmseq(a,b,c) ((a)-=(b)*(c))
 #define vfmaaeq(a,b,c,d,e) ((a)+=(b)*(c)+(d)*(e))
@@ -57,17 +61,15 @@
 #define vload(a) (a)
 #define vabs(a) fabs(a)
 #define vsqrt(a) sqrt(a)
-#define vlt(a,b) (((a)<(b))?1.:0.)
-#define vgt(a,b) (((a)>(b))?1.:0.)
-#define vge(a,b) (((a)>=(b))?1.:0.)
-#define vne(a,b) (((a)!=(b))?1.:0.)
-#define vand(a,b) ((((a)*(b))!=0.)?1.:0.)
-#define vor(a,b) ((((a)+(b))!=0.)?1.:0.)
+#define vlt(a,b) ((a)<(b))
+#define vgt(a,b) ((a)>(b))
+#define vge(a,b) ((a)>=(b))
+#define vne(a,b) ((a)!=(b))
+#define vand_mask(a,b) ((a)&&(b))
 static inline Tv vmin (Tv a, Tv b) { return (a<b) ? a : b; }
 static inline Tv vmax (Tv a, Tv b) { return (a>b) ? a : b; }
-#define vanyTrue(a) ((a)!=0.)
-#define vallTrue(a) ((a)!=0.)
-#define vblend(m,a,b) (((m)!=0.) ? (a) : (b))
+#define vanyTrue(a) (a)
+#define vallTrue(a) (a)
 #define vzero 0.
 #define vone 1.
 
@@ -85,13 +87,26 @@ static inline Tv vmax (Tv a, Tv b) { return (a>b) ? a : b; }
 #endif
 
 typedef __m128d Tv;
+typedef __m128d Tm;
+
+#if defined(__SSE4_1__)
+#define vblend__(m,a,b) _mm_blendv_pd(b,a,m)
+#else
+static inline Tv vblend__(Tv m, Tv a, Tv b)
+  { return _mm_or_pd(_mm_and_pd(a,m),_mm_andnot_pd(m,b)); }
+#endif
+#define vzero _mm_setzero_pd()
+#define vone _mm_set1_pd(1.)
 
 #define vadd(a,b) _mm_add_pd(a,b)
 #define vaddeq(a,b) a=_mm_add_pd(a,b)
+#define vaddeq_mask(mask,a,b) a=_mm_add_pd(a,vblend__(mask,b,vzero))
 #define vsub(a,b) _mm_sub_pd(a,b)
 #define vsubeq(a,b) a=_mm_sub_pd(a,b)
+#define vsubeq_mask(mask,a,b) a=_mm_sub_pd(a,vblend__(mask,b,vzero))
 #define vmul(a,b) _mm_mul_pd(a,b)
 #define vmuleq(a,b) a=_mm_mul_pd(a,b)
+#define vmuleq_mask(mask,a,b) a=_mm_mul_pd(a,vblend__(mask,b,vone))
 #define vfmaeq(a,b,c) a=_mm_add_pd(a,_mm_mul_pd(b,c))
 #define vfmseq(a,b,c) a=_mm_sub_pd(a,_mm_mul_pd(b,c))
 #define vfmaaeq(a,b,c,d,e) \
@@ -106,20 +121,11 @@ typedef __m128d Tv;
 #define vgt(a,b) _mm_cmpgt_pd(a,b)
 #define vge(a,b) _mm_cmpge_pd(a,b)
 #define vne(a,b) _mm_cmpneq_pd(a,b)
-#define vand(a,b) _mm_and_pd(a,b)
-#define vor(a,b) _mm_or_pd(a,b)
+#define vand_mask(a,b) _mm_and_pd(a,b)
 #define vmin(a,b) _mm_min_pd(a,b)
 #define vmax(a,b) _mm_max_pd(a,b);
 #define vanyTrue(a) (_mm_movemask_pd(a)!=0)
 #define vallTrue(a) (_mm_movemask_pd(a)==3)
-#if defined(__SSE4_1__)
-#define vblend(m,a,b) _mm_blendv_pd(b,a,m)
-#else
-static inline Tv vblend(Tv m, Tv a, Tv b)
-  { return _mm_or_pd(_mm_and_pd(a,m),_mm_andnot_pd(m,b)); }
-#endif
-#define vzero _mm_setzero_pd()
-#define vone _mm_set1_pd(1.)
 
 #endif
 
@@ -131,13 +137,21 @@ static inline Tv vblend(Tv m, Tv a, Tv b)
 #endif
 
 typedef __m256d Tv;
+typedef __m256d Tm;
+
+#define vblend__(m,a,b) _mm256_blendv_pd(b,a,m)
+#define vzero _mm256_setzero_pd()
+#define vone _mm256_set1_pd(1.)
 
 #define vadd(a,b) _mm256_add_pd(a,b)
 #define vaddeq(a,b) a=_mm256_add_pd(a,b)
+#define vaddeq_mask(mask,a,b) a=_mm256_add_pd(a,vblend__(mask,b,vzero))
 #define vsub(a,b) _mm256_sub_pd(a,b)
 #define vsubeq(a,b) a=_mm256_sub_pd(a,b)
+#define vsubeq_mask(mask,a,b) a=_mm256_sub_pd(a,vblend__(mask,b,vzero))
 #define vmul(a,b) _mm256_mul_pd(a,b)
 #define vmuleq(a,b) a=_mm256_mul_pd(a,b)
+#define vmuleq_mask(mask,a,b) a=_mm256_mul_pd(a,vblend__(mask,b,vone))
 #ifdef __FMA4__
 #define vfmaeq(a,b,c) a=_mm256_macc_pd(b,c,a)
 #define vfmseq(a,b,c) a=_mm256_nmacc_pd(b,c,a)
@@ -159,15 +173,51 @@
 #define vgt(a,b) _mm256_cmp_pd(a,b,_CMP_GT_OQ)
 #define vge(a,b) _mm256_cmp_pd(a,b,_CMP_GE_OQ)
 #define vne(a,b) _mm256_cmp_pd(a,b,_CMP_NEQ_OQ)
-#define vand(a,b) _mm256_and_pd(a,b)
-#define vor(a,b) _mm256_or_pd(a,b)
+#define vand_mask(a,b) _mm256_and_pd(a,b)
 #define vmin(a,b) _mm256_min_pd(a,b)
 #define vmax(a,b) _mm256_max_pd(a,b)
 #define vanyTrue(a) (_mm256_movemask_pd(a)!=0)
 #define vallTrue(a) (_mm256_movemask_pd(a)==15)
-#define vblend(m,a,b) _mm256_blendv_pd(b,a,m)
-#define vzero _mm256_setzero_pd()
-#define vone _mm256_set1_pd(1.)
+
+#endif
+
+#if (VLEN==8)
+
+#include <immintrin.h>
+
+typedef __m512d Tv;
+typedef __mmask8 Tm;
+
+#define vadd(a,b) _mm512_add_pd(a,b)
+#define vaddeq(a,b) a=_mm512_add_pd(a,b)
+#define vaddeq_mask(mask,a,b) a=_mm512_mask_add_pd(a,mask,a,b);
+#define vsub(a,b) _mm512_sub_pd(a,b)
+#define vsubeq(a,b) a=_mm512_sub_pd(a,b)
+#define vsubeq_mask(mask,a,b) a=_mm512_mask_sub_pd(a,mask,a,b);
+#define vmul(a,b) _mm512_mul_pd(a,b)
+#define vmuleq(a,b) a=_mm512_mul_pd(a,b)
+#define vmuleq_mask(mask,a,b) a=_mm512_mask_mul_pd(a,mask,a,b);
+#define vfmaeq(a,b,c) a=_mm512_fmadd_pd(b,c,a)
+#define vfmseq(a,b,c) a=_mm512_fnmadd_pd(b,c,a)
+#define vfmaaeq(a,b,c,d,e) a=_mm512_fmadd_pd(d,e,_mm512_fmadd_pd(b,c,a))
+#define vfmaseq(a,b,c,d,e) a=_mm512_fnmadd_pd(d,e,_mm512_fmadd_pd(b,c,a))
+#define vneg(a) _mm512_xor_pd(_mm512_set1_pd(-0.),a)
+#define vload(a) _mm512_set1_pd(a)
+#define vabs(a) (__m512d)_mm512_andnot_epi64((__m512i)_mm512_set1_pd(-0.),(__m512i)a)
+#define vsqrt(a) _mm512_sqrt_pd(a)
+#define vlt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_LT_OQ)
+#define vgt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GT_OQ)
+#define vge(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GE_OQ)
+#define vne(a,b) _mm512_cmp_pd_mask(a,b,_CMP_NEQ_OQ)
+#define vand(a,b) (__m512d)_mm512_and_epi64((__m512i)a,(__m512i)b)
+#define vor(a,b) (__m512d)_mm512_or_epi64((__m512i)a,(__m512i)b)
+#define vmin(a,b) _mm512_min_pd(a,b)
+#define vmax(a,b) _mm512_max_pd(a,b)
+#define vanyTrue(a) (a!=0)
+#define vallTrue(a) (a==255)
+
+#define vzero _mm512_setzero_pd()
+#define vone _mm512_set1_pd(1.)
 
 #endif
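sharp_vecsupport.h now pairs every vector type Tv with a mask type Tm. For
SSE2 and AVX a mask is simply a vector whose lanes are all-ones or all-zero,
and the masked updates are built on the internal helper vblend__(m,a,b),
which selects a where the mask is set (note the swapped argument order
relative to _mm_blendv_pd). A small self-contained check, again not from the
patch, of the pre-SSE4.1 bitwise fallback:

    /* check: the and/andnot blend matches "m ? a : b" per lane */
    #include <emmintrin.h>
    #include <stdio.h>

    static inline __m128d vblend__(__m128d m, __m128d a, __m128d b)
      { return _mm_or_pd(_mm_and_pd(a,m),_mm_andnot_pd(m,b)); }

    int main(void)
      {
      __m128d a = _mm_set_pd(2.0, 3.0);   /* lanes: {3.0, 2.0}   */
      __m128d b = _mm_set_pd(20.0, 30.0); /* lanes: {30.0, 20.0} */
      /* comparison results are all-ones/all-zero per lane, so the bitwise
         blend is exact; _mm_blendv_pd would only test the sign bit */
      __m128d m = _mm_cmpgt_pd(a, _mm_set1_pd(2.5));
      double r[2];
      _mm_storeu_pd(r, vblend__(m, a, b));
      printf("%g %g\n", r[0], r[1]);      /* expect: 3 20 */
      return 0;
      }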
diff --git a/libsharp/sharp_vecutil.h b/libsharp/sharp_vecutil.h
index fb6b60f..16bfa13 100644
--- a/libsharp/sharp_vecutil.h
+++ b/libsharp/sharp_vecutil.h
@@ -25,14 +25,16 @@
 /*! \file sharp_vecutil.h
  *  Functionality related to vector instruction support
  *
- *  Copyright (C) 2012 Max-Planck-Society
+ *  Copyright (C) 2012,2013 Max-Planck-Society
  *  \author Martin Reinecke
  */
 
 #ifndef SHARP_VECUTIL_H
 #define SHARP_VECUTIL_H
 
-#if (defined (__AVX__))
+#if (defined (__MIC__))
+#define VLEN 8
+#elif (defined (__AVX__))
 #define VLEN 4
 #elif (defined (__SSE2__))
 #define VLEN 2
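The new VLEN==8 branch is keyed off __MIC__ (Xeon Phi). There the comparison
intrinsics return a genuine bit mask (__mmask8) rather than a vector, which is
why Tm becomes a distinct type and why vanyTrue/vallTrue reduce to plain
integer tests (a!=0, and a==255 for eight lanes). A minimal sketch of that
interface, not from the patch, written with the equivalent AVX-512F spelling
and requiring a capable compiler and CPU (e.g. gcc -mavx512f):

    /* sketch: an AVX-512 mask register drives the masked update directly */
    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
      {
      __m512d v = _mm512_set1_pd(3.0);
      __mmask8 m = _mm512_cmp_pd_mask(v, _mm512_set1_pd(2.0), _CMP_GT_OQ);
      /* vmuleq_mask(mask,a,b): multiply only the lanes selected by m */
      v = _mm512_mask_mul_pd(v, m, v, _mm512_set1_pd(0.5));
      double r[8];
      _mm512_storeu_pd(r, v);
      /* all eight lanes compare true, so m==255, i.e. vallTrue(m) holds */
      printf("mask=%u first=%g\n", (unsigned)m, r[0]); /* mask=255 first=1.5 */
      return 0;
      }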