Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul
This commit is contained in:
commit
f98b0fbbd0
@ -38,11 +38,11 @@ class DistFieldSampler(Sampler):
|
|||||||
self.div_shuffle_dist = div_shuffle_dist
|
self.div_shuffle_dist = div_shuffle_dist
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
if self.shuffle:
|
||||||
# deterministically shuffle based on epoch
|
# deterministically shuffle based on epoch
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(self.epoch)
|
g.manual_seed(self.epoch)
|
||||||
|
|
||||||
if self.shuffle:
|
|
||||||
if self.div_data:
|
if self.div_data:
|
||||||
# shuffle files
|
# shuffle files
|
||||||
ind = torch.randperm(self.nfile, generator=g)
|
ind = torch.randperm(self.nfile, generator=g)
|
||||||
|
Loading…
Reference in New Issue
Block a user