From e24343e63d3b0599f67b458e69e6ce9b8824a73c Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Tue, 2 Apr 2024 16:19:35 +0200 Subject: [PATCH] Import patches from Deaglan fork commit 15e35b2c91d7b4c0a0158747c8059e756d110fad Author: Deaglan Bartlett Date: Mon Sep 11 22:51:30 2023 +0200 Fix power issue in plotting commit 048e53bb447faa569dfc86a0b29dd1216974fc21 Author: Deaglan Bartlett Date: Mon Sep 11 20:23:16 2023 +0200 Fix tgt norm correction bug commit 088fda3c4c799f3b6d1233ba7419201c78282483 Author: Deaglan Bartlett Date: Tue Sep 5 15:38:21 2023 +0200 Minor changes for coca --- map2map/data/fields.py | 19 +++++++++++++------ map2map/models/style.py | 3 +-- map2map/train.py | 2 ++ map2map/utils/figures.py | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 4ef0c9e..65a9867 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -56,6 +56,10 @@ class FieldDataset(Dataset): tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] self.tgt_files = list(zip(* tgt_file_lists)) + + # self.style_files = self.style_files[:1] + # self.in_files = self.in_files[:1] + # self.tgt_files = self.tgt_files[:1] if len(self.style_files) != len(self.in_files) != len(self.tgt_files): raise ValueError('number of style, input, and target files do not match') @@ -64,7 +68,7 @@ class FieldDataset(Dataset): if self.nfile == 0: raise FileNotFoundError('file not found for {}'.format(in_patterns)) - self.style_size = np.loadtxt(self.style_files[0]).shape[0] + self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0] self.in_chan = [np.load(f, mmap_mode='r').shape[0] for f in self.in_files[0]] self.tgt_chan = [np.load(f, mmap_mode='r').shape[0] @@ -155,7 +159,7 @@ class FieldDataset(Dataset): def __getitem__(self, idx): ifile, icrop = divmod(idx, self.ncrop) - style = np.loadtxt(self.style_files[ifile]) + style = np.loadtxt(self.style_files[ifile], ndmin=1) in_fields = [np.load(f) for f in self.in_files[ifile]] tgt_fields = [np.load(f) for f in self.tgt_files[ifile]] @@ -188,13 +192,16 @@ class FieldDataset(Dataset): tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] if self.in_norms is not None: - for norm, x in zip(self.in_norms, in_fields): + in_norms = np.atleast_1d(self.in_norms[ifile]) + for norm, x in zip(in_norms, in_fields): norm = import_attr(norm, norms, callback_at=self.callback_at) norm(x, **self.kwargs) if self.tgt_norms is not None: - for norm, x in zip(self.tgt_norms, tgt_fields): - norm = import_attr(norm, norms, callback_at=self.callback_at) - norm(x, **self.kwargs) + tgt_norms = np.atleast_1d(self.tgt_norms[ifile]) + for norm, x in zip(tgt_norms, tgt_fields): + # norm = import_attr(norm, norms, callback_at=self.callback_at) + # norm(x, **self.kwargs) + x /= norm if self.augment: flip_axes = flip(in_fields, None, self.ndim) diff --git a/map2map/models/style.py b/map2map/models/style.py index 2c1c80c..24a8349 100644 --- a/map2map/models/style.py +++ b/map2map/models/style.py @@ -95,7 +95,6 @@ class ConvStyled3d(nn.Module): nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu') self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1 - if resample is None: K3 = (kernel_size,) * 3 self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) @@ -156,7 +155,7 @@ class ConvStyled3d(nn.Module): w = w.reshape(N * C0, C1, *K3) x = x.reshape(1, N * Cin, *DHWin) - x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N) + x = self.conv(x, w, bias=self.bias.repeat(N), stride=self.stride, groups=N) _, _, *DHWout = x.shape x = x.reshape(N, Cout, *DHWout) diff --git a/map2map/train.py b/map2map/train.py index d9690d8..db80b16 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -292,6 +292,7 @@ def train(epoch, loader, model, criterion, logger.add_scalar('loss/epoch/train', epoch_loss[0], global_step=epoch+1) + skip_chan = 0 fig = plt_slices( input[-1], output[-1, skip_chan:], target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:], @@ -351,6 +352,7 @@ def validate(epoch, loader, model, criterion, logger, device, args): if rank == 0: logger.add_scalar('loss/epoch/val', epoch_loss[0], global_step=epoch+1) + skip_chan = 0 fig = plt_slices( input[-1], output[-1, skip_chan:], target[-1, skip_chan:], diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py index 316bbf6..c6ae4a8 100644 --- a/map2map/utils/figures.py +++ b/map2map/utils/figures.py @@ -146,7 +146,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs): ks, Ps = [], [] for field in fields: - k, P, _ = power(field) + k, P, _ = power.power(field) ks.append(k) Ps.append(P)