diff --git a/jaxpm/nn.py b/jaxpm/nn.py index bac0f5d..933ea53 100644 --- a/jaxpm/nn.py +++ b/jaxpm/nn.py @@ -35,9 +35,9 @@ class NeuralSplineFourierFilter(hk.Module): self.n_knots = n_knots self.latent_size = latent_size - def __call__(self, k, a): + def __call__(self, x, a): """ - k: array, scale, normalized to fftfreq default + x: array, scale, normalized to fftfreq default a: scalar, scale factor """ @@ -57,4 +57,4 @@ class NeuralSplineFourierFilter(hk.Module): # Augment with repeating points ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))]) - return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file + return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file