Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul

This commit is contained in:
Yin Li 2020-07-22 02:38:45 -04:00
commit f98b0fbbd0

View File

@ -38,11 +38,11 @@ class DistFieldSampler(Sampler):
self.div_shuffle_dist = div_shuffle_dist
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.div_data:
# shuffle files
ind = torch.randperm(self.nfile, generator=g)