diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 19d1ddb..e36dc99 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -135,20 +135,20 @@ class FieldDataset(Dataset): # first add full fields when num_fields > num_GPU for i in range(rank, self.nfile // world_size * world_size, world_size): - self.samples.append( + self.samples.extend(list( range(i * self.ncrop, (i + 1) * self.ncrop) - ) + )) # then split the rest into fractions of fields # drop the last incomplete batch of samples frac_start = self.nfile // world_size * world_size * self.ncrop frac_samples = self.nfile % world_size * self.ncrop // world_size - self.samples.append(range(frac_start + rank * frac_samples, - frac_start + (rank + 1) * frac_samples)) - - self.samples = np.concatenate(self.samples) + self.samples.extend(list( + range(frac_start + rank * frac_samples, + frac_start + (rank + 1) * frac_samples) + )) else: - self.samples = np.arange(self.nfile * self.ncrop) + self.samples = list(range(self.nfile * self.ncrop)) self.nsample = len(self.samples) self.rank = rank