Fix seeding bug introduced in the completely wrong commit f64b1e4

This commit is contained in:
Yin Li 2020-01-06 20:20:05 -05:00
parent 848dc87169
commit 9cf97b3ac1
7 changed files with 11 additions and 14 deletions

View File

@ -63,7 +63,7 @@ def add_train_args(parser):
# help='momentum') # help='momentum')
parser.add_argument('--weight-decay', default=0., type=float, parser.add_argument('--weight-decay', default=0., type=float,
help='weight decay') help='weight decay')
parser.add_argument('--seed', type=int, parser.add_argument('--seed', default=42, type=int,
help='seed for initializing training') help='seed for initializing training')
parser.add_argument('--div-data', action='store_true', parser.add_argument('--div-data', action='store_true',

View File

@ -45,7 +45,7 @@ class FieldDataset(Dataset):
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0]) self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0]) self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
self.size = np.load(self.in_files[0][0]).shape[-3:] self.size = np.load(self.in_files[0][0]).shape[1:]
self.size = np.asarray(self.size) self.size = np.asarray(self.size)
self.ndim = len(self.size) self.ndim = len(self.size)

View File

@ -1,6 +1,5 @@
import os import os
import shutil import shutil
import random
import torch import torch
from torch.multiprocessing import spawn from torch.multiprocessing import spawn
from torch.distributed import init_process_group, destroy_process_group, all_reduce from torch.distributed import init_process_group, destroy_process_group, all_reduce
@ -15,8 +14,6 @@ from .models import narrow_like
def node_worker(args): def node_worker(args):
if args.seed is None:
args.seed = random.randint(0, 65535)
torch.manual_seed(args.seed) # NOTE: why here not in gpu_worker? torch.manual_seed(args.seed) # NOTE: why here not in gpu_worker?
#torch.backends.cudnn.deterministic = True # NOTE: test perf #torch.backends.cudnn.deterministic = True # NOTE: test perf
@ -155,8 +152,8 @@ def gpu_worker(local_rank, args):
if min_loss is None or val_loss < min_loss: if min_loss is None or val_loss < min_loss:
min_loss = val_loss min_loss = val_loss
shutil.copyfile(ckpt_file, best_file.format(epoch + 1)) shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
if os.path.isfile(best_file.format(epoch)): #if os.path.isfile(best_file.format(epoch)):
os.remove(best_file.format(epoch)) # os.remove(best_file.format(epoch))
destroy_process_group() destroy_process_group()

View File

@ -27,7 +27,7 @@ tgt_dir="nonlin"
test_dirs="*99" test_dirs="*99"
files="dis/512x000.npy" files="dis.npy"
in_files="$files" in_files="$files"
tgt_files="$files" tgt_files="$files"

View File

@ -29,7 +29,7 @@ tgt_dir="nonlin"
train_dirs="*[0-8]" train_dirs="*[0-8]"
val_dirs="*[0-8]9" val_dirs="*[0-8]9"
files="dis/512x000.npy" files="dis.npy"
in_files="$files" in_files="$files"
tgt_files="$files" tgt_files="$files"
@ -41,7 +41,7 @@ srun m2m.py train \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.dis --augment --crop 100 --pad 42 \ --norms cosmology.dis --augment --crop 100 --pad 42 \
--model VNet \ --model VNet \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \ --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
--cache --div-data --cache --div-data
# --load-state checkpoint.pth \ # --load-state checkpoint.pth \

View File

@ -27,7 +27,7 @@ tgt_dir="nonlin"
test_dirs="*99" test_dirs="*99"
files="vel/512x000.npy" files="vel.npy"
in_files="$files" in_files="$files"
tgt_files="$files" tgt_files="$files"

View File

@ -29,7 +29,7 @@ tgt_dir="nonlin"
train_dirs="*[0-8]" train_dirs="*[0-8]"
val_dirs="*[0-8]9" val_dirs="*[0-8]9"
files="vel/512x000.npy" files="vel.npy"
in_files="$files" in_files="$files"
tgt_files="$files" tgt_files="$files"
@ -41,7 +41,7 @@ srun m2m.py train \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.vel --augment --crop 100 --pad 42 \ --norms cosmology.vel --augment --crop 100 --pad 42 \
--model VNet \ --model VNet \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \ --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
--cache --div-data --cache --div-data
# --load-state checkpoint.pth \ # --load-state checkpoint.pth \