Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul
This commit is contained in:
commit
f98b0fbbd0
@ -38,11 +38,11 @@ class DistFieldSampler(Sampler):
|
||||
self.div_shuffle_dist = div_shuffle_dist
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
if self.shuffle:
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
if self.div_data:
|
||||
# shuffle files
|
||||
ind = torch.randperm(self.nfile, generator=g)
|
||||
|
Loading…
Reference in New Issue
Block a user