diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8447f8a..9a0f783 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -2,24 +2,6 @@ import jax.numpy as jnp import numpy as np -def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size - """ - k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - kd = kd.reshape(kdshape) - - k.append(kd.astype(dtype)) - del kd, kdshape - return k - - def gradient_kernel(kvec, direction, order=1): """ Computes the gradient kernel in the requested direction