Import patches from Deaglan fork

commit 15e35b2c91d7b4c0a0158747c8059e756d110fad
Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk>
Date:   Mon Sep 11 22:51:30 2023 +0200

    Fix power issue in plotting

commit 048e53bb447faa569dfc86a0b29dd1216974fc21
Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk>
Date:   Mon Sep 11 20:23:16 2023 +0200

    Fix tgt norm correction bug

commit 088fda3c4c799f3b6d1233ba7419201c78282483
Author: Deaglan Bartlett <deaglan.bartlett@physics.ox.ac.uk>
Date:   Tue Sep 5 15:38:21 2023 +0200

    Minor changes for coca
This commit is contained in:
Guilhem Lavaux 2024-04-02 16:19:35 +02:00
parent 352482d475
commit e24343e63d
4 changed files with 17 additions and 9 deletions

View File

@ -57,6 +57,10 @@ class FieldDataset(Dataset):
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists)) self.tgt_files = list(zip(* tgt_file_lists))
# self.style_files = self.style_files[:1]
# self.in_files = self.in_files[:1]
# self.tgt_files = self.tgt_files[:1]
if len(self.style_files) != len(self.in_files) != len(self.tgt_files): if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
raise ValueError('number of style, input, and target files do not match') raise ValueError('number of style, input, and target files do not match')
self.nfile = len(self.in_files) self.nfile = len(self.in_files)
@ -64,7 +68,7 @@ class FieldDataset(Dataset):
if self.nfile == 0: if self.nfile == 0:
raise FileNotFoundError('file not found for {}'.format(in_patterns)) raise FileNotFoundError('file not found for {}'.format(in_patterns))
self.style_size = np.loadtxt(self.style_files[0]).shape[0] self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
self.in_chan = [np.load(f, mmap_mode='r').shape[0] self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]] for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0] self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@ -155,7 +159,7 @@ class FieldDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
ifile, icrop = divmod(idx, self.ncrop) ifile, icrop = divmod(idx, self.ncrop)
style = np.loadtxt(self.style_files[ifile]) style = np.loadtxt(self.style_files[ifile], ndmin=1)
in_fields = [np.load(f) for f in self.in_files[ifile]] in_fields = [np.load(f) for f in self.in_files[ifile]]
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]] tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
@ -188,13 +192,16 @@ class FieldDataset(Dataset):
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
if self.in_norms is not None: if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields): in_norms = np.atleast_1d(self.in_norms[ifile])
for norm, x in zip(in_norms, in_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at) norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x, **self.kwargs) norm(x, **self.kwargs)
if self.tgt_norms is not None: if self.tgt_norms is not None:
for norm, x in zip(self.tgt_norms, tgt_fields): tgt_norms = np.atleast_1d(self.tgt_norms[ifile])
norm = import_attr(norm, norms, callback_at=self.callback_at) for norm, x in zip(tgt_norms, tgt_fields):
norm(x, **self.kwargs) # norm = import_attr(norm, norms, callback_at=self.callback_at)
# norm(x, **self.kwargs)
x /= norm
if self.augment: if self.augment:
flip_axes = flip(in_fields, None, self.ndim) flip_axes = flip(in_fields, None, self.ndim)

View File

@ -95,7 +95,6 @@ class ConvStyled3d(nn.Module):
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5), nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
mode='fan_in', nonlinearity='leaky_relu') mode='fan_in', nonlinearity='leaky_relu')
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1 self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
if resample is None: if resample is None:
K3 = (kernel_size,) * 3 K3 = (kernel_size,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
@ -156,7 +155,7 @@ class ConvStyled3d(nn.Module):
w = w.reshape(N * C0, C1, *K3) w = w.reshape(N * C0, C1, *K3)
x = x.reshape(1, N * Cin, *DHWin) x = x.reshape(1, N * Cin, *DHWin)
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N) x = self.conv(x, w, bias=self.bias.repeat(N), stride=self.stride, groups=N)
_, _, *DHWout = x.shape _, _, *DHWout = x.shape
x = x.reshape(N, Cout, *DHWout) x = x.reshape(N, Cout, *DHWout)

View File

@ -292,6 +292,7 @@ def train(epoch, loader, model, criterion,
logger.add_scalar('loss/epoch/train', epoch_loss[0], logger.add_scalar('loss/epoch/train', epoch_loss[0],
global_step=epoch+1) global_step=epoch+1)
skip_chan = 0
fig = plt_slices( fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:], input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:],
@ -351,6 +352,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
if rank == 0: if rank == 0:
logger.add_scalar('loss/epoch/val', epoch_loss[0], logger.add_scalar('loss/epoch/val', epoch_loss[0],
global_step=epoch+1) global_step=epoch+1)
skip_chan = 0
fig = plt_slices( fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:], input[-1], output[-1, skip_chan:], target[-1, skip_chan:],

View File

@ -146,7 +146,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
ks, Ps = [], [] ks, Ps = [], []
for field in fields: for field in fields:
k, P, _ = power(field) k, P, _ = power.power(field)
ks.append(k) ks.append(k)
Ps.append(P) Ps.append(P)