mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
adding neural network
This commit is contained in:
parent
3e1b3d8a3b
commit
3f9dfa504a
1 changed files with 60 additions and 0 deletions
60
jaxpm/nn.py
Normal file
60
jaxpm/nn.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import haiku as hk
|
||||||
|
|
||||||
|
def _deBoorVectorized(x, t, c, p):
|
||||||
|
"""
|
||||||
|
Evaluates S(x).
|
||||||
|
|
||||||
|
Args
|
||||||
|
----
|
||||||
|
x: position
|
||||||
|
t: array of knot positions, needs to be padded as described above
|
||||||
|
c: array of control points
|
||||||
|
p: degree of B-spline
|
||||||
|
"""
|
||||||
|
k = jnp.digitize(x, t) -1
|
||||||
|
|
||||||
|
d = [c[j + k - p] for j in range(0, p+1)]
|
||||||
|
for r in range(1, p+1):
|
||||||
|
for j in range(p, r-1, -1):
|
||||||
|
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
|
||||||
|
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
|
||||||
|
return d[p]
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralSplineFourierFilter(hk.Module):
|
||||||
|
"""A rotationally invariant filter parameterized by
|
||||||
|
a b-spline with parameters specified by a small NN."""
|
||||||
|
|
||||||
|
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||||
|
"""
|
||||||
|
n_knots: number of control points for the spline
|
||||||
|
"""
|
||||||
|
super().__init__(name=name)
|
||||||
|
self.n_knots = n_knots
|
||||||
|
self.latent_size = latent_size
|
||||||
|
|
||||||
|
def __call__(self, k, a):
|
||||||
|
"""
|
||||||
|
k: array, scale, normalized to fftfreq default
|
||||||
|
a: scalar, scale factor
|
||||||
|
"""
|
||||||
|
|
||||||
|
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||||
|
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||||
|
|
||||||
|
w = hk.Linear(self.n_knots+1)(net)
|
||||||
|
k = hk.Linear(self.n_knots-1)(net)
|
||||||
|
|
||||||
|
# make sure the knots sum to 1 and are in the interval 0,1
|
||||||
|
k = jnp.concatenate([jnp.zeros((1,)),
|
||||||
|
jnp.cumsum(jax.nn.softmax(k))])
|
||||||
|
|
||||||
|
w = jnp.concatenate([jnp.zeros((1,)),
|
||||||
|
w])
|
||||||
|
|
||||||
|
# Augment with repeating points
|
||||||
|
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
||||||
|
|
||||||
|
return _deBoorVectorized(jnp.clip(k/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
Loading…
Add table
Reference in a new issue