From 5195a28582a8e9302d433f6187bee65d84bb0aa1 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sat, 26 Mar 2022 03:03:59 +0100 Subject: [PATCH] fix minor issue --- jaxpm/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxpm/nn.py b/jaxpm/nn.py index bac0f5d..933ea53 100644 --- a/jaxpm/nn.py +++ b/jaxpm/nn.py @@ -35,9 +35,9 @@ class NeuralSplineFourierFilter(hk.Module): self.n_knots = n_knots self.latent_size = latent_size - def __call__(self, k, a): + def __call__(self, x, a): """ - k: array, scale, normalized to fftfreq default + x: array, scale, normalized to fftfreq default a: scalar, scale factor """ @@ -57,4 +57,4 @@ class NeuralSplineFourierFilter(hk.Module): # Augment with repeating points ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))]) - return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file + return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file