From 1180e6ac8826cbe82a7f8f5a1c249e853f727e6c Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Wed, 4 Mar 2015 22:24:37 +0100 Subject: [PATCH] MPI FFTW calls --- src/fourier/fft/fftw_calls.hpp | 2 +- src/fourier/fft/fftw_calls_mpi.hpp | 98 ++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 src/fourier/fft/fftw_calls_mpi.hpp diff --git a/src/fourier/fft/fftw_calls.hpp b/src/fourier/fft/fftw_calls.hpp index f61f94b..b627cf4 100644 --- a/src/fourier/fft/fftw_calls.hpp +++ b/src/fourier/fft/fftw_calls.hpp @@ -113,7 +113,7 @@ public: \ FFTW_CALLS_BASE(double, fftw); FFTW_CALLS_BASE(float, fftwf); - +#undef FFTW_CALLS_BASE }; #endif diff --git a/src/fourier/fft/fftw_calls_mpi.hpp b/src/fourier/fft/fftw_calls_mpi.hpp new file mode 100644 index 0000000..675db5a --- /dev/null +++ b/src/fourier/fft/fftw_calls_mpi.hpp @@ -0,0 +1,98 @@ +#ifndef __MPI_FFTW_UNIFIED_CALLS_HPP +#define __MPI_FFTW_UNIFIED_CALLS_HPP + +#include +#include + +namespace CosmoTool +{ + +static inline void init_fftw_mpi() +{ + fftw_mpi_init(); +} + +static inline void done_fftw_mpi() +{ + fftw_mpi_cleanup(); +} + +template class FFTW_MPI_Calls {}; + + +#define FFTW_MPI_CALLS_BASE(rtype, prefix) \ + template<> \ +class FFTW_MPI_Calls { \ +public: \ + typedef rtype real_type; \ + typedef prefix ## _complex complex_type; \ + typedef prefix ## _plan plan_type; \ + \ + static complex_type *alloc_complex(size_t N) { return prefix ## _alloc_complex(N); } \ + static real_type *alloc_real(size_t N) { return prefix ## _alloc_real(N); } \ + static void free(void *p) { fftw_free(p); } \ +\ + static ptrdiff_t local_size_2d(ptrdiff_t N0, ptrdiff_t N1, MPI_Comm comm, \ + ptrdiff_t *local_n0, ptrdiff_t *local_0_start) { \ + return prefix ## _mpi_local_size_2d(N0, N1, comm, local_n0, local_0_start); \ + } \ +\ + static ptrdiff_t local_size_3d(ptrdiff_t N0, ptrdiff_t N1, ptrdiff_t N2, MPI_Comm comm, \ + ptrdiff_t *local_n0, ptrdiff_t *local_0_start) { \ + return prefix ## _mpi_local_size_3d(N0, N1, N2, comm, local_n0, local_0_start); \ + } \ +\ + static void execute(plan_type p) { prefix ## _execute(p); } \ + static void execute_r2c(plan_type p, real_type *in, complex_type *out) { prefix ## _mpi_execute_dft_r2c(p, in, out); } \ + static void execute_c2r(plan_type p, complex_type *in, real_type *out) { prefix ## _mpi_execute_dft_c2r(p, in, out); } \ +\ + static plan_type plan_dft_r2c_2d(int Nx, int Ny, \ + real_type *in, complex_type *out, \ + MPI_Comm comm, unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_r2c_2d(Nx, Ny, in, out, \ + comm, flags); \ + } \ + static plan_type plan_dft_c2r_2d(int Nx, int Ny, \ + complex_type *in, real_type *out, \ + MPI_Comm comm, unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_c2r_2d(Nx, Ny, in, out, \ + comm, flags); \ + } \ + static plan_type plan_dft_r2c_3d(int Nx, int Ny, int Nz, \ + real_type *in, complex_type *out, \ + MPI_Comm comm, unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_r2c_3d(Nx, Ny, Nz, in, out, comm, flags); \ + } \ + static plan_type plan_dft_c2r_3d(int Nx, int Ny, int Nz, \ + complex_type *in, real_type *out, \ + MPI_Comm comm, \ + unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_c2r_3d(Nx, Ny, Nz, in, out, comm, flags); \ + } \ +\ + static plan_type plan_dft_r2c(int rank, const ptrdiff_t *n, real_type *in, \ + complex_type *out, MPI_Comm comm, unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_r2c(rank, n, in, out, comm, flags); \ + } \ + static plan_type plan_dft_c2r(int rank, const ptrdiff_t *n, complex_type *in, \ + real_type *out, MPI_Comm comm, unsigned flags) \ + { \ + return prefix ## _mpi_plan_dft_c2r(rank, n, in, out, comm, flags); \ + } \ + static void destroy_plan(plan_type plan) { prefix ## _destroy_plan(plan); } \ +} + + +FFTW_MPI_CALLS_BASE(double, fftw); +FFTW_MPI_CALLS_BASE(float, fftwf); + +#undef FFTW_MPI_CALLS_BASE + +}; + +#endif