From 907dc42bdff569803842b9a207febe7c00646f2f Mon Sep 17 00:00:00 2001 From: EiffL Date: Sat, 26 Mar 2022 02:59:39 +0100 Subject: [PATCH] adding neural network --- jaxpm/nn.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 jaxpm/nn.py diff --git a/jaxpm/nn.py b/jaxpm/nn.py new file mode 100644 index 0000000..bac0f5d --- /dev/null +++ b/jaxpm/nn.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp +import haiku as hk + +def _deBoorVectorized(x, t, c, p): + """ + Evaluates S(x). + + Args + ---- + x: position + t: array of knot positions, needs to be padded as described above + c: array of control points + p: degree of B-spline + """ + k = jnp.digitize(x, t) -1 + + d = [c[j + k - p] for j in range(0, p+1)] + for r in range(1, p+1): + for j in range(p, r-1, -1): + alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p]) + d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j] + return d[p] + + +class NeuralSplineFourierFilter(hk.Module): + """A rotationally invariant filter parameterized by + a b-spline with parameters specified by a small NN.""" + + def __init__(self, n_knots=8, latent_size=16, name=None): + """ + n_knots: number of control points for the spline + """ + super().__init__(name=name) + self.n_knots = n_knots + self.latent_size = latent_size + + def __call__(self, k, a): + """ + k: array, scale, normalized to fftfreq default + a: scalar, scale factor + """ + + net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a))) + net = jnp.sin(hk.Linear(self.latent_size)(net)) + + w = hk.Linear(self.n_knots+1)(net) + k = hk.Linear(self.n_knots-1)(net) + + # make sure the knots sum to 1 and are in the interval 0,1 + k = jnp.concatenate([jnp.zeros((1,)), + jnp.cumsum(jax.nn.softmax(k))]) + + w = jnp.concatenate([jnp.zeros((1,)), + w]) + + # Augment with repeating points + ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))]) + + return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file