Replace assert with raise

This commit is contained in:
Yin Li 2020-09-13 14:22:40 -04:00
parent 55f61b9a70
commit a579c9655b

View File

@ -50,11 +50,12 @@ class FieldDataset(Dataset):
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists)) self.tgt_files = list(zip(* tgt_file_lists))
assert len(self.in_files) == len(self.tgt_files), \ if len(self.in_files) != len(self.tgt_files):
'number of input and target fields do not match' raise ValueError('number of input and target fields do not match')
self.nfile = len(self.in_files) self.nfile = len(self.in_files)
assert self.nfile > 0, 'file not found for {}'.format(in_patterns) if self.nfile == 0:
raise FileNotFoundError('file not found for {}'.format(in_patterns))
self.in_chan = [np.load(f, mmap_mode='r').shape[0] self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]] for f in self.in_files[0]]
@ -65,14 +66,12 @@ class FieldDataset(Dataset):
self.size = np.asarray(self.size) self.size = np.asarray(self.size)
self.ndim = len(self.size) self.ndim = len(self.size)
if in_norms is not None: if in_norms is not None and len(in_patterns) != len(in_norms):
assert len(in_patterns) == len(in_norms), \ raise ValueError('numbers of input normalization functions and fields do not match')
'numbers of input normalization functions and fields do not match'
self.in_norms = in_norms self.in_norms = in_norms
if tgt_norms is not None: if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
assert len(tgt_patterns) == len(tgt_norms), \ raise ValueError('numbers of target normalization functions and fields do not match')
'numbers of target normalization functions and fields do not match'
self.tgt_norms = tgt_norms self.tgt_norms = tgt_norms
self.callback_at = callback_at self.callback_at = callback_at