Update slurm script to include adversary model
This commit is contained in:
parent
94ce018cb8
commit
6e06682751
@ -62,20 +62,20 @@ def add_train_args(parser):
|
|||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
|
|
||||||
parser.add_argument('--epochs', default=128, type=int,
|
|
||||||
help='total number of epochs to run')
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
parser.add_argument('--lr', default=0.001, type=float,
|
parser.add_argument('--lr', default=0.001, type=float,
|
||||||
help='initial learning rate')
|
help='initial learning rate')
|
||||||
# parser.add_argument('--momentum', default=0.9, type=float,
|
# parser.add_argument('--momentum', default=0.9, type=float,
|
||||||
# help='momentum')
|
# help='momentum')
|
||||||
parser.add_argument('--weight-decay', default=0., type=float,
|
parser.add_argument('--weight-decay', default=0, type=float,
|
||||||
help='weight decay')
|
help='weight decay')
|
||||||
parser.add_argument('--adv-lr', type=float,
|
parser.add_argument('--adv-lr', type=float,
|
||||||
help='initial adversary learning rate')
|
help='initial adversary learning rate')
|
||||||
parser.add_argument('--adv-weight-decay', type=float,
|
parser.add_argument('--adv-weight-decay', type=float,
|
||||||
help='adversary weight decay')
|
help='adversary weight decay')
|
||||||
|
parser.add_argument('--epochs', default=128, type=int,
|
||||||
|
help='total number of epochs to run')
|
||||||
parser.add_argument('--seed', default=42, type=int,
|
parser.add_argument('--seed', default=42, type=int,
|
||||||
help='seed for initializing training')
|
help='seed for initializing training')
|
||||||
|
|
||||||
|
@ -13,11 +13,11 @@
|
|||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
module load gcc python3
|
#module load gcc python3
|
||||||
|
source $HOME/anaconda/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
||||||
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
|
|
||||||
|
|
||||||
|
|
||||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
@ -35,7 +35,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--norms cosmology.dis --crop 256 --pad 42 \
|
--norms cosmology.dis --crop 256 --pad 20 \
|
||||||
--model VNet \
|
--model VNet \
|
||||||
--load-state best_model.pth \
|
--load-state best_model.pth \
|
||||||
--batches 1 --loader-workers 0 \
|
--batches 1 --loader-workers 0 \
|
||||||
|
@ -14,11 +14,12 @@
|
|||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
module load gcc python3
|
#module load gcc python3
|
||||||
|
source $HOME/anaconda/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
export MASTER_ADDR=$HOSTNAME
|
export MASTER_ADDR=$HOSTNAME
|
||||||
export MASTER_PORT=8888
|
export MASTER_PORT=60606
|
||||||
|
|
||||||
|
|
||||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
@ -39,9 +40,10 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--norms cosmology.dis --augment --crop 100 --pad 42 \
|
--norms cosmology.dis --augment --crop 128 --pad 20 \
|
||||||
--model VNet \
|
--model VNet --adv-model UNet --cgan \
|
||||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
|
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||||
|
--epochs 128 --seed $RANDOM \
|
||||||
--cache --div-data
|
--cache --div-data
|
||||||
# --load-state checkpoint.pth \
|
# --load-state checkpoint.pth \
|
||||||
|
|
||||||
|
@ -13,11 +13,11 @@
|
|||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
module load gcc python3
|
#module load gcc python3
|
||||||
|
source $HOME/anaconda/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
||||||
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
|
|
||||||
|
|
||||||
|
|
||||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
@ -35,7 +35,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--norms cosmology.vel --crop 256 --pad 42 \
|
--norms cosmology.vel --crop 256 --pad 20 \
|
||||||
--model VNet \
|
--model VNet \
|
||||||
--load-state best_model.pth \
|
--load-state best_model.pth \
|
||||||
--batches 1 --loader-workers 0 \
|
--batches 1 --loader-workers 0 \
|
||||||
|
@ -14,11 +14,12 @@
|
|||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
module load gcc python3
|
#module load gcc python3
|
||||||
|
source $HOME/anaconda/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
export MASTER_ADDR=$HOSTNAME
|
export MASTER_ADDR=$HOSTNAME
|
||||||
export MASTER_PORT=8888
|
export MASTER_PORT=60606
|
||||||
|
|
||||||
|
|
||||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
@ -39,9 +40,10 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--norms cosmology.vel --augment --crop 100 --pad 42 \
|
--norms cosmology.vel --augment --crop 128 --pad 20 \
|
||||||
--model VNet \
|
--model VNet --adv-model UNet --cgan \
|
||||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \
|
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||||
|
--epochs 128 --seed $RANDOM \
|
||||||
--cache --div-data
|
--cache --div-data
|
||||||
# --load-state checkpoint.pth \
|
# --load-state checkpoint.pth \
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user