forked from guilhem_lavaux/JaxPM
fix minor issue
This commit is contained in:
parent
3f9dfa504a
commit
10b093f07f
1 changed files with 3 additions and 3 deletions
|
@ -35,9 +35,9 @@ class NeuralSplineFourierFilter(hk.Module):
|
|||
self.n_knots = n_knots
|
||||
self.latent_size = latent_size
|
||||
|
||||
def __call__(self, k, a):
|
||||
def __call__(self, x, a):
|
||||
"""
|
||||
k: array, scale, normalized to fftfreq default
|
||||
x: array, scale, normalized to fftfreq default
|
||||
a: scalar, scale factor
|
||||
"""
|
||||
|
||||
|
@ -57,4 +57,4 @@ class NeuralSplineFourierFilter(hk.Module):
|
|||
# Augment with repeating points
|
||||
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
||||
|
||||
return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
||||
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
Loading…
Add table
Reference in a new issue