Fix bug due to strange numpy type conversion

This commit is contained in:
Yin Li 2020-07-13 23:50:47 -04:00
parent d0ab596e11
commit c61b6d6fab

View File

@ -135,20 +135,20 @@ class FieldDataset(Dataset):
# first add full fields when num_fields > num_GPU
for i in range(rank, self.nfile // world_size * world_size,
world_size):
self.samples.append(
self.samples.extend(list(
range(i * self.ncrop, (i + 1) * self.ncrop)
)
))
# then split the rest into fractions of fields
# drop the last incomplete batch of samples
frac_start = self.nfile // world_size * world_size * self.ncrop
frac_samples = self.nfile % world_size * self.ncrop // world_size
self.samples.append(range(frac_start + rank * frac_samples,
frac_start + (rank + 1) * frac_samples))
self.samples = np.concatenate(self.samples)
self.samples.extend(list(
range(frac_start + rank * frac_samples,
frac_start + (rank + 1) * frac_samples)
))
else:
self.samples = np.arange(self.nfile * self.ncrop)
self.samples = list(range(self.nfile * self.ncrop))
self.nsample = len(self.samples)
self.rank = rank