diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index 4ef0c9e..65a9867 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -56,6 +56,10 @@ class FieldDataset(Dataset):
         tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
         self.tgt_files = list(zip(* tgt_file_lists))
+
+        # self.style_files = self.style_files[:1]
+        # self.in_files = self.in_files[:1]
+        # self.tgt_files = self.tgt_files[:1]
 
         if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
             raise ValueError('number of style, input, and target files do not match')
 
@@ -64,7 +68,7 @@ class FieldDataset(Dataset):
         if self.nfile == 0:
             raise FileNotFoundError('file not found for {}'.format(in_patterns))
 
-        self.style_size = np.loadtxt(self.style_files[0]).shape[0]
+        self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
         self.in_chan = [np.load(f, mmap_mode='r').shape[0]
                         for f in self.in_files[0]]
         self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@@ -155,7 +159,7 @@ class FieldDataset(Dataset):
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        style = np.loadtxt(self.style_files[ifile])
+        style = np.loadtxt(self.style_files[ifile], ndmin=1)
         in_fields = [np.load(f) for f in self.in_files[ifile]]
         tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
 
@@ -188,13 +192,16 @@ class FieldDataset(Dataset):
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
 
         if self.in_norms is not None:
-            for norm, x in zip(self.in_norms, in_fields):
+            in_norms = np.atleast_1d(self.in_norms[ifile])
+            for norm, x in zip(in_norms, in_fields):
                 norm = import_attr(norm, norms, callback_at=self.callback_at)
                 norm(x, **self.kwargs)
 
         if self.tgt_norms is not None:
-            for norm, x in zip(self.tgt_norms, tgt_fields):
-                norm = import_attr(norm, norms, callback_at=self.callback_at)
-                norm(x, **self.kwargs)
+            tgt_norms = np.atleast_1d(self.tgt_norms[ifile])
+            for norm, x in zip(tgt_norms, tgt_fields):
+                # norm = import_attr(norm, norms, callback_at=self.callback_at)
+                # norm(x, **self.kwargs)
+                x /= norm
         if self.augment:
             flip_axes = flip(in_fields, None, self.ndim)
diff --git a/map2map/models/style.py b/map2map/models/style.py
index 2c1c80c..24a8349 100644
--- a/map2map/models/style.py
+++ b/map2map/models/style.py
@@ -95,7 +95,6 @@ class ConvStyled3d(nn.Module):
         nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
                                  mode='fan_in', nonlinearity='leaky_relu')
         self.style_bias = nn.Parameter(torch.ones(in_chan))  # NOTE: init to 1
-
         if resample is None:
             K3 = (kernel_size,) * 3
             self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
@@ -156,7 +155,7 @@ class ConvStyled3d(nn.Module):
 
         w = w.reshape(N * C0, C1, *K3)
         x = x.reshape(1, N * Cin, *DHWin)
-        x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
+        x = self.conv(x, w, bias=self.bias.repeat(N), stride=self.stride, groups=N)
         _, _, *DHWout = x.shape
         x = x.reshape(N, Cout, *DHWout)
 
diff --git a/map2map/train.py b/map2map/train.py
index d9690d8..db80b16 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -292,6 +292,7 @@ def train(epoch, loader, model, criterion,
         logger.add_scalar('loss/epoch/train', epoch_loss[0],
                           global_step=epoch+1)
 
+        skip_chan = 0
         fig = plt_slices(
             input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
             output[-1, skip_chan:] - target[-1, skip_chan:],
@@ -351,6 +352,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     if rank == 0:
         logger.add_scalar('loss/epoch/val', epoch_loss[0],
                           global_step=epoch+1)
 
+        skip_chan = 0
         fig = plt_slices(
             input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py
index 316bbf6..c6ae4a8 100644
--- a/map2map/utils/figures.py
+++ b/map2map/utils/figures.py
@@ -146,7 +146,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
 
     ks, Ps = [], []
     for field in fields:
-        k, P, _ = power(field)
+        k, P, _ = power.power(field)
         ks.append(k)
         Ps.append(P)
 
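
A note on the ConvStyled3d change above: when the per-sample styled weights are folded into a single grouped convolution (groups=N), the weight tensor ends up with N * Cout output channels, so the Cout-long bias has to be tiled N times to match. The following standalone sketch is not part of the diff and uses hypothetical shape names (N, Cin, Cout, K, D); it only illustrates why passing bias=self.bias unchanged would fail there.

    import torch
    import torch.nn.functional as F

    N, Cin, Cout, K, D = 2, 3, 4, 3, 8        # hypothetical batch, channel, kernel and box sizes
    x = torch.randn(N, Cin, D, D, D)          # input fields
    w = torch.randn(N, Cout, Cin, K, K, K)    # per-sample (style-modulated) weights
    b = torch.zeros(Cout)                     # one bias per output channel, shared across the batch

    # Fold the batch into the channel axis and run the N convolutions as one grouped conv.
    x = x.reshape(1, N * Cin, D, D, D)
    w = w.reshape(N * Cout, Cin, K, K, K)

    # conv3d expects one bias entry per output channel; there are now N * Cout of them,
    # laid out group by group, so the Cout-long bias is tiled N times.
    y = F.conv3d(x, w, bias=b.repeat(N), groups=N)
    y = y.reshape(N, Cout, *y.shape[2:])
    print(y.shape)   # torch.Size([2, 4, 6, 6, 6])

The np.loadtxt(..., ndmin=1) edits in fields.py address a similar shape issue: a style file holding a single scalar would otherwise load as a 0-d array, and .shape[0] would raise an IndexError.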