mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
fix minor issue
This commit is contained in:
parent
3f9dfa504a
commit
10b093f07f
1 changed files with 3 additions and 3 deletions
|
@ -35,9 +35,9 @@ class NeuralSplineFourierFilter(hk.Module):
|
||||||
self.n_knots = n_knots
|
self.n_knots = n_knots
|
||||||
self.latent_size = latent_size
|
self.latent_size = latent_size
|
||||||
|
|
||||||
def __call__(self, k, a):
|
def __call__(self, x, a):
|
||||||
"""
|
"""
|
||||||
k: array, scale, normalized to fftfreq default
|
x: array, scale, normalized to fftfreq default
|
||||||
a: scalar, scale factor
|
a: scalar, scale factor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -57,4 +57,4 @@ class NeuralSplineFourierFilter(hk.Module):
|
||||||
# Augment with repeating points
|
# Augment with repeating points
|
||||||
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
||||||
|
|
||||||
return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
Loading…
Add table
Reference in a new issue