Fix bug due to strange numpy type conversion
This commit is contained in:
parent
d0ab596e11
commit
c61b6d6fab
@ -135,20 +135,20 @@ class FieldDataset(Dataset):
|
||||
# first add full fields when num_fields > num_GPU
|
||||
for i in range(rank, self.nfile // world_size * world_size,
|
||||
world_size):
|
||||
self.samples.append(
|
||||
self.samples.extend(list(
|
||||
range(i * self.ncrop, (i + 1) * self.ncrop)
|
||||
)
|
||||
))
|
||||
|
||||
# then split the rest into fractions of fields
|
||||
# drop the last incomplete batch of samples
|
||||
frac_start = self.nfile // world_size * world_size * self.ncrop
|
||||
frac_samples = self.nfile % world_size * self.ncrop // world_size
|
||||
self.samples.append(range(frac_start + rank * frac_samples,
|
||||
frac_start + (rank + 1) * frac_samples))
|
||||
|
||||
self.samples = np.concatenate(self.samples)
|
||||
self.samples.extend(list(
|
||||
range(frac_start + rank * frac_samples,
|
||||
frac_start + (rank + 1) * frac_samples)
|
||||
))
|
||||
else:
|
||||
self.samples = np.arange(self.nfile * self.ncrop)
|
||||
self.samples = list(range(self.nfile * self.ncrop))
|
||||
self.nsample = len(self.samples)
|
||||
|
||||
self.rank = rank
|
||||
|
Loading…
Reference in New Issue
Block a user