Fix bug due to strange numpy type conversion
This commit is contained in:
parent
d0ab596e11
commit
c61b6d6fab
@ -135,20 +135,20 @@ class FieldDataset(Dataset):
|
|||||||
# first add full fields when num_fields > num_GPU
|
# first add full fields when num_fields > num_GPU
|
||||||
for i in range(rank, self.nfile // world_size * world_size,
|
for i in range(rank, self.nfile // world_size * world_size,
|
||||||
world_size):
|
world_size):
|
||||||
self.samples.append(
|
self.samples.extend(list(
|
||||||
range(i * self.ncrop, (i + 1) * self.ncrop)
|
range(i * self.ncrop, (i + 1) * self.ncrop)
|
||||||
)
|
))
|
||||||
|
|
||||||
# then split the rest into fractions of fields
|
# then split the rest into fractions of fields
|
||||||
# drop the last incomplete batch of samples
|
# drop the last incomplete batch of samples
|
||||||
frac_start = self.nfile // world_size * world_size * self.ncrop
|
frac_start = self.nfile // world_size * world_size * self.ncrop
|
||||||
frac_samples = self.nfile % world_size * self.ncrop // world_size
|
frac_samples = self.nfile % world_size * self.ncrop // world_size
|
||||||
self.samples.append(range(frac_start + rank * frac_samples,
|
self.samples.extend(list(
|
||||||
frac_start + (rank + 1) * frac_samples))
|
range(frac_start + rank * frac_samples,
|
||||||
|
frac_start + (rank + 1) * frac_samples)
|
||||||
self.samples = np.concatenate(self.samples)
|
))
|
||||||
else:
|
else:
|
||||||
self.samples = np.arange(self.nfile * self.ncrop)
|
self.samples = list(range(self.nfile * self.ncrop))
|
||||||
self.nsample = len(self.samples)
|
self.nsample = len(self.samples)
|
||||||
|
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
Loading…
Reference in New Issue
Block a user