Replace assert with raise

This commit is contained in:
Yin Li 2020-09-13 14:22:40 -04:00
parent 55f61b9a70
commit a579c9655b

View File

@ -50,11 +50,12 @@ class FieldDataset(Dataset):
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists))
assert len(self.in_files) == len(self.tgt_files), \
'number of input and target fields do not match'
if len(self.in_files) != len(self.tgt_files):
raise ValueError('number of input and target fields do not match')
self.nfile = len(self.in_files)
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
if self.nfile == 0:
raise FileNotFoundError('file not found for {}'.format(in_patterns))
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]]
@ -65,14 +66,12 @@ class FieldDataset(Dataset):
self.size = np.asarray(self.size)
self.ndim = len(self.size)
if in_norms is not None:
assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match'
if in_norms is not None and len(in_patterns) != len(in_norms):
raise ValueError('numbers of input normalization functions and fields do not match')
self.in_norms = in_norms
if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match'
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
raise ValueError('numbers of target normalization functions and fields do not match')
self.tgt_norms = tgt_norms
self.callback_at = callback_at