Import patches from Deaglan fork
commit 15e35b2c91d7b4c0a0158747c8059e756d110fad Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Mon Sep 11 22:51:30 2023 +0200 Fix power issue in plotting commit 048e53bb447faa569dfc86a0b29dd1216974fc21 Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Mon Sep 11 20:23:16 2023 +0200 Fix tgt norm correction bug commit 088fda3c4c799f3b6d1233ba7419201c78282483 Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Tue Sep 5 15:38:21 2023 +0200 Minor changes for coca
This commit is contained in:
parent
352482d475
commit
e24343e63d
4 changed files with 17 additions and 9 deletions
|
@ -56,6 +56,10 @@ class FieldDataset(Dataset):
|
|||
|
||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
# self.style_files = self.style_files[:1]
|
||||
# self.in_files = self.in_files[:1]
|
||||
# self.tgt_files = self.tgt_files[:1]
|
||||
|
||||
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of style, input, and target files do not match')
|
||||
|
@ -64,7 +68,7 @@ class FieldDataset(Dataset):
|
|||
if self.nfile == 0:
|
||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||
|
||||
self.style_size = np.loadtxt(self.style_files[0]).shape[0]
|
||||
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
|
||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.in_files[0]]
|
||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
|
@ -155,7 +159,7 @@ class FieldDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
ifile, icrop = divmod(idx, self.ncrop)
|
||||
|
||||
style = np.loadtxt(self.style_files[ifile])
|
||||
style = np.loadtxt(self.style_files[ifile], ndmin=1)
|
||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||
|
||||
|
@ -188,13 +192,16 @@ class FieldDataset(Dataset):
|
|||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
|
||||
if self.in_norms is not None:
|
||||
for norm, x in zip(self.in_norms, in_fields):
|
||||
in_norms = np.atleast_1d(self.in_norms[ifile])
|
||||
for norm, x in zip(in_norms, in_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x, **self.kwargs)
|
||||
if self.tgt_norms is not None:
|
||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x, **self.kwargs)
|
||||
tgt_norms = np.atleast_1d(self.tgt_norms[ifile])
|
||||
for norm, x in zip(tgt_norms, tgt_fields):
|
||||
# norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
# norm(x, **self.kwargs)
|
||||
x /= norm
|
||||
|
||||
if self.augment:
|
||||
flip_axes = flip(in_fields, None, self.ndim)
|
||||
|
|
|
@ -95,7 +95,6 @@ class ConvStyled3d(nn.Module):
|
|||
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
|
||||
mode='fan_in', nonlinearity='leaky_relu')
|
||||
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
|
||||
|
||||
if resample is None:
|
||||
K3 = (kernel_size,) * 3
|
||||
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||
|
@ -156,7 +155,7 @@ class ConvStyled3d(nn.Module):
|
|||
|
||||
w = w.reshape(N * C0, C1, *K3)
|
||||
x = x.reshape(1, N * Cin, *DHWin)
|
||||
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
|
||||
x = self.conv(x, w, bias=self.bias.repeat(N), stride=self.stride, groups=N)
|
||||
_, _, *DHWout = x.shape
|
||||
x = x.reshape(N, Cout, *DHWout)
|
||||
|
||||
|
|
|
@ -292,6 +292,7 @@ def train(epoch, loader, model, criterion,
|
|||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
|
||||
skip_chan = 0
|
||||
fig = plt_slices(
|
||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||
|
@ -351,6 +352,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
skip_chan = 0
|
||||
|
||||
fig = plt_slices(
|
||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
|
|
|
@ -146,7 +146,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
|
|||
|
||||
ks, Ps = [], []
|
||||
for field in fields:
|
||||
k, P, _ = power(field)
|
||||
k, P, _ = power.power(field)
|
||||
ks.append(k)
|
||||
Ps.append(P)
|
||||
|
||||
|
|
Loading…
Reference in a new issue