Replace assert with raise
This commit is contained in:
parent
55f61b9a70
commit
a579c9655b
@ -50,11 +50,12 @@ class FieldDataset(Dataset):
|
||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'number of input and target fields do not match'
|
||||
if len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of input and target fields do not match')
|
||||
self.nfile = len(self.in_files)
|
||||
|
||||
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
|
||||
if self.nfile == 0:
|
||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||
|
||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.in_files[0]]
|
||||
@ -65,14 +66,12 @@ class FieldDataset(Dataset):
|
||||
self.size = np.asarray(self.size)
|
||||
self.ndim = len(self.size)
|
||||
|
||||
if in_norms is not None:
|
||||
assert len(in_patterns) == len(in_norms), \
|
||||
'numbers of input normalization functions and fields do not match'
|
||||
if in_norms is not None and len(in_patterns) != len(in_norms):
|
||||
raise ValueError('numbers of input normalization functions and fields do not match')
|
||||
self.in_norms = in_norms
|
||||
|
||||
if tgt_norms is not None:
|
||||
assert len(tgt_patterns) == len(tgt_norms), \
|
||||
'numbers of target normalization functions and fields do not match'
|
||||
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
|
||||
raise ValueError('numbers of target normalization functions and fields do not match')
|
||||
self.tgt_norms = tgt_norms
|
||||
|
||||
self.callback_at = callback_at
|
||||
|
Loading…
Reference in New Issue
Block a user