Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul
This commit is contained in:
commit
f98b0fbbd0
1 changed files with 4 additions and 4 deletions
|
@ -38,11 +38,11 @@ class DistFieldSampler(Sampler):
|
||||||
self.div_shuffle_dist = div_shuffle_dist
|
self.div_shuffle_dist = div_shuffle_dist
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# deterministically shuffle based on epoch
|
|
||||||
g = torch.Generator()
|
|
||||||
g.manual_seed(self.epoch)
|
|
||||||
|
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
|
# deterministically shuffle based on epoch
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(self.epoch)
|
||||||
|
|
||||||
if self.div_data:
|
if self.div_data:
|
||||||
# shuffle files
|
# shuffle files
|
||||||
ind = torch.randperm(self.nfile, generator=g)
|
ind = torch.randperm(self.nfile, generator=g)
|
||||||
|
|
Loading…
Add table
Reference in a new issue