JaxPM/jaxpm/nn.py

61 lines
1.7 KiB
Python
Raw Normal View History

2024-07-09 14:54:34 -04:00
import haiku as hk
2022-03-26 02:59:39 +01:00
import jax
import jax.numpy as jnp
2024-07-09 14:54:34 -04:00
2022-03-26 02:59:39 +01:00
def _deBoorVectorized(x, t, c, p):
    """Evaluate the B-spline S(x) with de Boor's algorithm, element-wise.

    S(x) = sum_i c[i] * B_{i,p}(x), where B_{i,p} are the degree-``p``
    B-spline basis functions over the knot vector ``t``.

    Args:
        x: position(s) at which to evaluate the spline (scalar or array).
        t: non-decreasing knot vector with ``len(t) == len(c) + p + 1``,
            typically clamped, i.e. each end knot repeated ``p + 1`` times.
            The spline's domain is ``[t[p], t[len(c)]]``.
        c: 1-D array of control points.
        p: degree of the B-spline.

    Returns:
        S(x), with the shape of ``x``. Points outside the domain are
        evaluated with the nearest end polynomial piece (extrapolation).
    """
    n_ctrl = jnp.shape(c)[0]
    # Knot span index: t[k] <= x < t[k + 1].
    # Clamp it to the valid spans [p, n_ctrl - 1]. For x inside the domain this
    # is a no-op; at the right end point (and beyond) digitize lands past the
    # repeated end knots, and below t[0] it yields -1. Either way c and t would
    # be indexed out of range and a zero knot difference would produce NaN/inf.
    k = jnp.clip(jnp.digitize(x, t) - 1, p, n_ctrl - 1)
    # The p + 1 control points that influence span k.
    d = [c[j + k - p] for j in range(p + 1)]
    for r in range(1, p + 1):
        # Walk j downwards so d[j - 1] still holds the previous level's value.
        for j in range(p, r - 1, -1):
            alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
            d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
    return d[p]


class NeuralSplineFourierFilter(hk.Module):
    """A rotationally invariant filter parameterized by a B-spline.

    The filter is a cubic B-spline in the rescaled input ``x`` whose knot
    positions and control-point values are produced by a small
    sine-activated MLP conditioned on the scale factor ``a``.
    """

    def __init__(self, n_knots=8, latent_size=16, name=None):
        """
        Args:
            n_knots: number of distinct spline knots on [0, 1], end points
                included; the cubic spline then has ``n_knots + 2`` control
                points. Must be at least 2.
            latent_size: width of the two hidden layers of the MLP.
            name: optional Haiku module name.

        Raises:
            ValueError: if ``n_knots < 2``.
        """
        super().__init__(name=name)
        if n_knots < 2:
            # With fewer than two knots the knot vector built in __call__ has
            # no knot at 1 besides the three padding ones, so the cubic
            # spline's domain collapses to {0} and evaluating it on [0, 1)
            # would index past the control points.
            raise ValueError(f"n_knots must be at least 2, got {n_knots}")
        self.n_knots = n_knots
        self.latent_size = latent_size

    def __call__(self, x, a):
        """Evaluate the filter at ``x`` for scale factor ``a``.

        Args:
            x: array, scale, normalized to fftfreq default. It is divided by
                sqrt(3) and clipped to [0, 1 - 1e-4] before the spline is
                evaluated, so all values beyond sqrt(3) * (1 - 1e-4) give the
                same filter value.
            a: scalar, scale factor.

        Returns:
            The filter values at ``x`` (for scalar ``a``, same shape as ``x``).
        """
        # Small sine-activated MLP conditioned on a. NOTE: keep the hk.Linear
        # creation order unchanged -- Haiku names the parameters ("linear",
        # "linear_1", ...) by creation order, so reordering would change the
        # parameter tree.
        net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
        net = jnp.sin(hk.Linear(self.latent_size)(net))
        w = hk.Linear(self.n_knots + 1)(net)
        k = hk.Linear(self.n_knots - 1)(net)

        # softmax gives positive increments summing to 1; their cumulative sum,
        # with a leading 0, is a non-decreasing knot sequence from 0 to 1.
        knots = jnp.concatenate(
            [jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])

        # Pin the first control point to 0: a clamped spline equals its first
        # control point at the left end, so the filter is 0 at x = 0.
        w = jnp.concatenate([jnp.zeros((1, )), w])

        # Repeat the end knots 3 more times (degree + 1 = 4 copies in total,
        # counting the 0 and 1 already in `knots`) to clamp the cubic spline.
        ak = jnp.concatenate([jnp.zeros((3, )), knots, jnp.ones((3, ))])

        return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
                                 3)