mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 03:51:11 +00:00
Applying formatting
This commit is contained in:
parent
835fa89aec
commit
f28442bb48
14 changed files with 565 additions and 445 deletions
60
jaxpm/nn.py
60
jaxpm/nn.py
|
@ -1,6 +1,7 @@
|
|||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import haiku as hk
|
||||
|
||||
|
||||
def _deBoorVectorized(x, t, c, p):
|
||||
"""
|
||||
|
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
|
|||
c: array of control points
|
||||
p: degree of B-spline
|
||||
"""
|
||||
k = jnp.digitize(x, t) -1
|
||||
|
||||
d = [c[j + k - p] for j in range(0, p+1)]
|
||||
for r in range(1, p+1):
|
||||
for j in range(p, r-1, -1):
|
||||
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
|
||||
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
|
||||
k = jnp.digitize(x, t) - 1
|
||||
|
||||
d = [c[j + k - p] for j in range(0, p + 1)]
|
||||
for r in range(1, p + 1):
|
||||
for j in range(p, r - 1, -1):
|
||||
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
|
||||
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
|
||||
return d[p]
|
||||
|
||||
|
||||
class NeuralSplineFourierFilter(hk.Module):
|
||||
"""A rotationally invariant filter parameterized by
|
||||
"""A rotationally invariant filter parameterized by
|
||||
a b-spline with parameters specified by a small NN."""
|
||||
|
||||
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||
"""
|
||||
n_knots: number of control points for the spline
|
||||
"""
|
||||
n_knots: number of control points for the spline
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
self.n_knots = n_knots
|
||||
self.latent_size = latent_size
|
||||
super().__init__(name=name)
|
||||
self.n_knots = n_knots
|
||||
self.latent_size = latent_size
|
||||
|
||||
def __call__(self, x, a):
|
||||
"""
|
||||
def __call__(self, x, a):
|
||||
"""
|
||||
x: array, scale, normalized to fftfreq default
|
||||
a: scalar, scale factor
|
||||
"""
|
||||
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||
|
||||
w = hk.Linear(self.n_knots+1)(net)
|
||||
k = hk.Linear(self.n_knots-1)(net)
|
||||
|
||||
# make sure the knots sum to 1 and are in the interval 0,1
|
||||
k = jnp.concatenate([jnp.zeros((1,)),
|
||||
jnp.cumsum(jax.nn.softmax(k))])
|
||||
w = hk.Linear(self.n_knots + 1)(net)
|
||||
k = hk.Linear(self.n_knots - 1)(net)
|
||||
|
||||
w = jnp.concatenate([jnp.zeros((1,)),
|
||||
w])
|
||||
# make sure the knots sum to 1 and are in the interval 0,1
|
||||
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
|
||||
|
||||
# Augment with repeating points
|
||||
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
||||
w = jnp.concatenate([jnp.zeros((1, )), w])
|
||||
|
||||
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
||||
# Augment with repeating points
|
||||
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
|
||||
|
||||
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
|
||||
3)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue