diff --git a/libsharp/sharp_vecsupport.h b/libsharp/sharp_vecsupport.h index 0fe7540..2c01b1d 100644 --- a/libsharp/sharp_vecsupport.h +++ b/libsharp/sharp_vecsupport.h @@ -122,6 +122,9 @@ static inline Tv vblend(Tv m, Tv a, Tv b) #if (VLEN==4) #include +#ifdef __FMA4__ +#include +#endif typedef __m256d Tv; @@ -131,12 +134,21 @@ typedef __m256d Tv; #define vsubeq(a,b) a=_mm256_sub_pd(a,b) #define vmul(a,b) _mm256_mul_pd(a,b) #define vmuleq(a,b) a=_mm256_mul_pd(a,b) +#ifdef __FMA4__ +#define vfmaeq(a,b,c) a=_mm256_macc_pd(b,c,a) +#define vfmseq(a,b,c) a=_mm256_nmacc_pd(b,c,a) +#define vfmaaeq(a,b,c,d,e) \ + a=_mm256_macc_pd(d,e,_mm256_macc_pd(b,c,a)) +#define vfmaseq(a,b,c,d,e) \ + a=_mm256_nmacc_pd(d,e,_mm256_macc_pd(b,c,a)) +#else #define vfmaeq(a,b,c) a=_mm256_add_pd(a,_mm256_mul_pd(b,c)) #define vfmseq(a,b,c) a=_mm256_sub_pd(a,_mm256_mul_pd(b,c)) #define vfmaaeq(a,b,c,d,e) \ a=_mm256_add_pd(a,_mm256_add_pd(_mm256_mul_pd(b,c),_mm256_mul_pd(d,e))) #define vfmaseq(a,b,c,d,e) \ a=_mm256_add_pd(a,_mm256_sub_pd(_mm256_mul_pd(b,c),_mm256_mul_pd(d,e))) +#endif #define vneg(a) _mm256_xor_pd(_mm256_set1_pd(-0.),a) #define vload(a) _mm256_set1_pd(a) #define vabs(a) _mm256_andnot_pd(_mm256_set1_pd(-0.),a)