Import patches from Deaglan fork
commit 15e35b2c91d7b4c0a0158747c8059e756d110fad Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Mon Sep 11 22:51:30 2023 +0200 Fix power issue in plotting commit 048e53bb447faa569dfc86a0b29dd1216974fc21 Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Mon Sep 11 20:23:16 2023 +0200 Fix tgt norm correction bug commit 088fda3c4c799f3b6d1233ba7419201c78282483 Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk> Date: Tue Sep 5 15:38:21 2023 +0200 Minor changes for coca
This commit is contained in:
parent
352482d475
commit
e24343e63d
@ -56,6 +56,10 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||||
self.tgt_files = list(zip(* tgt_file_lists))
|
self.tgt_files = list(zip(* tgt_file_lists))
|
||||||
|
|
||||||
|
# self.style_files = self.style_files[:1]
|
||||||
|
# self.in_files = self.in_files[:1]
|
||||||
|
# self.tgt_files = self.tgt_files[:1]
|
||||||
|
|
||||||
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
||||||
raise ValueError('number of style, input, and target files do not match')
|
raise ValueError('number of style, input, and target files do not match')
|
||||||
@ -64,7 +68,7 @@ class FieldDataset(Dataset):
|
|||||||
if self.nfile == 0:
|
if self.nfile == 0:
|
||||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||||
|
|
||||||
self.style_size = np.loadtxt(self.style_files[0]).shape[0]
|
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
|
||||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||||
for f in self.in_files[0]]
|
for f in self.in_files[0]]
|
||||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||||
@ -155,7 +159,7 @@ class FieldDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
ifile, icrop = divmod(idx, self.ncrop)
|
ifile, icrop = divmod(idx, self.ncrop)
|
||||||
|
|
||||||
style = np.loadtxt(self.style_files[ifile])
|
style = np.loadtxt(self.style_files[ifile], ndmin=1)
|
||||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||||
|
|
||||||
@ -188,13 +192,16 @@ class FieldDataset(Dataset):
|
|||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
|
|
||||||
if self.in_norms is not None:
|
if self.in_norms is not None:
|
||||||
for norm, x in zip(self.in_norms, in_fields):
|
in_norms = np.atleast_1d(self.in_norms[ifile])
|
||||||
|
for norm, x in zip(in_norms, in_fields):
|
||||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||||
norm(x, **self.kwargs)
|
norm(x, **self.kwargs)
|
||||||
if self.tgt_norms is not None:
|
if self.tgt_norms is not None:
|
||||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
tgt_norms = np.atleast_1d(self.tgt_norms[ifile])
|
||||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
for norm, x in zip(tgt_norms, tgt_fields):
|
||||||
norm(x, **self.kwargs)
|
# norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||||
|
# norm(x, **self.kwargs)
|
||||||
|
x /= norm
|
||||||
|
|
||||||
if self.augment:
|
if self.augment:
|
||||||
flip_axes = flip(in_fields, None, self.ndim)
|
flip_axes = flip(in_fields, None, self.ndim)
|
||||||
|
@ -95,7 +95,6 @@ class ConvStyled3d(nn.Module):
|
|||||||
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
|
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
|
||||||
mode='fan_in', nonlinearity='leaky_relu')
|
mode='fan_in', nonlinearity='leaky_relu')
|
||||||
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
|
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
|
||||||
|
|
||||||
if resample is None:
|
if resample is None:
|
||||||
K3 = (kernel_size,) * 3
|
K3 = (kernel_size,) * 3
|
||||||
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||||
@ -156,7 +155,7 @@ class ConvStyled3d(nn.Module):
|
|||||||
|
|
||||||
w = w.reshape(N * C0, C1, *K3)
|
w = w.reshape(N * C0, C1, *K3)
|
||||||
x = x.reshape(1, N * Cin, *DHWin)
|
x = x.reshape(1, N * Cin, *DHWin)
|
||||||
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
|
x = self.conv(x, w, bias=self.bias.repeat(N), stride=self.stride, groups=N)
|
||||||
_, _, *DHWout = x.shape
|
_, _, *DHWout = x.shape
|
||||||
x = x.reshape(N, Cout, *DHWout)
|
x = x.reshape(N, Cout, *DHWout)
|
||||||
|
|
||||||
|
@ -292,6 +292,7 @@ def train(epoch, loader, model, criterion,
|
|||||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
|
||||||
|
skip_chan = 0
|
||||||
fig = plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||||
@ -351,6 +352,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
skip_chan = 0
|
||||||
|
|
||||||
fig = plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
|
@ -146,7 +146,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
|
|||||||
|
|
||||||
ks, Ps = [], []
|
ks, Ps = [], []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
k, P, _ = power(field)
|
k, P, _ = power.power(field)
|
||||||
ks.append(k)
|
ks.append(k)
|
||||||
Ps.append(P)
|
Ps.append(P)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user