Fix seeding bug introduced in the completely wrong commit f64b1e4
parent 848dc87169
commit 9cf97b3ac1
@@ -30,7 +30,7 @@ def add_common_args(parser):
             help='path to load model, optimizer, rng, etc.')

     parser.add_argument('--batches', default=1, type=int,
             help='mini-batch size, per GPU in training or in total in testing')
     parser.add_argument('--loader-workers', default=0, type=int,
             help='number of data loading workers, per GPU in training or '
                 'in total in testing')
@@ -63,7 +63,7 @@ def add_train_args(parser):
     #        help='momentum')
     parser.add_argument('--weight-decay', default=0., type=float,
             help='weight decay')
-    parser.add_argument('--seed', type=int,
+    parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')

     parser.add_argument('--div-data', action='store_true',
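Note: --seed previously had no default and was filled in at runtime per process (see the node_worker hunk below); it now defaults to 42, so bare invocations are reproducible and any randomness must be injected explicitly by the caller. A minimal sketch of the resulting behavior (the argument name matches this diff; everything else is illustrative):

import argparse

import torch

# A fixed default seed makes bare invocations reproducible; launchers
# can still randomize by passing e.g. --seed $RANDOM on the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=42, type=int,
        help='seed for initializing training')
args = parser.parse_args([])      # no flags -> args.seed == 42

torch.manual_seed(args.seed)      # same seed -> identical RNG stream
print(torch.rand(3))              # prints the same tensor on every run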
@@ -45,7 +45,7 @@ class FieldDataset(Dataset):
         self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
         self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])

-        self.size = np.load(self.in_files[0][0]).shape[-3:]
+        self.size = np.load(self.in_files[0][0]).shape[1:]
         self.size = np.asarray(self.size)
         self.ndim = len(self.size)
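Note: the dataset stores fields channel-first (in_channels is read from shape[0] above), so shape[1:] takes whatever spatial axes follow the channel axis, whereas the old shape[-3:] hard-coded three spatial dimensions and would misread lower-dimensional data. A quick illustration with made-up shapes:

import numpy as np

# Hypothetical channel-first fields: 3 channels, 2 or 3 spatial dims.
field2d = np.zeros((3, 128, 128))
field3d = np.zeros((3, 128, 128, 128))

print(field2d.shape[-3:])   # (3, 128, 128) -- swallows the channel axis
print(field2d.shape[1:])    # (128, 128)    -- spatial size only
print(field3d.shape[1:])    # (128, 128, 128)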
@@ -1,6 +1,5 @@
 import os
 import shutil
-import random
 import torch
 from torch.multiprocessing import spawn
 from torch.distributed import init_process_group, destroy_process_group, all_reduce
@@ -15,8 +14,6 @@ from .models import narrow_like


 def node_worker(args):
-    if args.seed is None:
-        args.seed = random.randint(0, 65535)
     torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
     #torch.backends.cudnn.deterministic = True  # NOTE: test perf

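This hunk is most likely the seeding bug the commit message refers to: node_worker runs once per node, so drawing the seed with random.randint inside it gave each node its own seed, and thus its own model initialization, in multi-node jobs. With the fallback removed, every node seeds torch from the same args.seed, and the randomness moves to the launch scripts (--seed $RANDOM below). A sketch of the old failure mode (function name hypothetical):

import random

# Hypothetical reconstruction of the old per-node behavior: each node
# process drew its own seed, so ranks disagreed on initialization.
def old_node_seed(seed=None):
    if seed is None:
        seed = random.randint(0, 65535)   # independent draw per node
    return seed

# Four "nodes" almost certainly end up with four different seeds:
print({old_node_seed() for _ in range(4)})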
@@ -155,8 +152,8 @@ def gpu_worker(local_rank, args):
             if min_loss is None or val_loss < min_loss:
                 min_loss = val_loss
                 shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
-                if os.path.isfile(best_file.format(epoch)):
-                    os.remove(best_file.format(epoch))
+                #if os.path.isfile(best_file.format(epoch)):
+                #    os.remove(best_file.format(epoch))

     destroy_process_group()

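Side effect worth noting: with the cleanup commented out, an improving epoch no longer deletes the previous epoch's best-model file, so every improvement leaves its own snapshot on disk, at the cost of extra storage. The old deletion also keyed on best_file.format(epoch), so it could only remove a file from the immediately preceding epoch; improvements in non-consecutive epochs left stale files behind regardless, which is presumably part of why it was disabled.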
@@ -27,7 +27,7 @@ tgt_dir="nonlin"

 test_dirs="*99"

-files="dis/512x000.npy"
+files="dis.npy"
 in_files="$files"
 tgt_files="$files"

@@ -29,7 +29,7 @@ tgt_dir="nonlin"
 train_dirs="*[0-8]"
 val_dirs="*[0-8]9"

-files="dis/512x000.npy"
+files="dis.npy"
 in_files="$files"
 tgt_files="$files"

@@ -41,7 +41,7 @@ srun m2m.py train \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --norms cosmology.dis --augment --crop 100 --pad 42 \
     --model VNet \
-    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
+    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
     --cache --div-data
     # --load-state checkpoint.pth \

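Note on --seed $RANDOM: the shell expands $RANDOM once, when the srun command line is built, so every node in the job receives the same seed and the per-node disagreement fixed above cannot recur. Bash's $RANDOM is a 15-bit value (0-32767), a narrower range than the old random.randint(0, 65535) but ample for seeding.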
@@ -27,7 +27,7 @@ tgt_dir="nonlin"

 test_dirs="*99"

-files="vel/512x000.npy"
+files="vel.npy"
 in_files="$files"
 tgt_files="$files"

@@ -29,7 +29,7 @@ tgt_dir="nonlin"
 train_dirs="*[0-8]"
 val_dirs="*[0-8]9"

-files="vel/512x000.npy"
+files="vel.npy"
 in_files="$files"
 tgt_files="$files"

@@ -41,7 +41,7 @@ srun m2m.py train \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --norms cosmology.vel --augment --crop 100 --pad 42 \
     --model VNet \
-    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
+    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
     --cache --div-data
     # --load-state checkpoint.pth \