diff --git a/map2map/args.py b/map2map/args.py index 3d79f46..04fe23e 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -62,20 +62,20 @@ def add_train_args(parser): parser.add_argument('--cgan', action='store_true', help='enable conditional GAN') - parser.add_argument('--epochs', default=128, type=int, - help='total number of epochs to run') parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer from torch.optim') parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate') # parser.add_argument('--momentum', default=0.9, type=float, # help='momentum') - parser.add_argument('--weight-decay', default=0., type=float, + parser.add_argument('--weight-decay', default=0, type=float, help='weight decay') parser.add_argument('--adv-lr', type=float, help='initial adversary learning rate') parser.add_argument('--adv-weight-decay', type=float, help='adversary weight decay') + parser.add_argument('--epochs', default=128, type=int, + help='total number of epochs to run') parser.add_argument('--seed', default=42, type=int, help='seed for initializing training') diff --git a/scripts/dis2dis-test.slurm b/scripts/dis2dis-test.slurm index 60b0b5c..7c2361c 100644 --- a/scripts/dis2dis-test.slurm +++ b/scripts/dis2dis-test.slurm @@ -13,11 +13,11 @@ hostname; pwd; date -module load gcc python3 +#module load gcc python3 +source $HOME/anaconda/bin/activate torch export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE -echo OMP_NUM_THREADS = $OMP_NUM_THREADS data_root_dir="/mnt/ceph/users/yinli/Quijote" @@ -35,7 +35,7 @@ tgt_files="$files" m2m.py test \ --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ - --norms cosmology.dis --crop 256 --pad 42 \ + --norms cosmology.dis --crop 256 --pad 20 \ --model VNet \ --load-state best_model.pth \ --batches 1 --loader-workers 0 \ diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm index fa517b7..e48cbcc 100644 --- a/scripts/dis2dis.slurm +++ b/scripts/dis2dis.slurm @@ -14,11 +14,12 @@ hostname; pwd; date -module load gcc python3 +#module load gcc python3 +source $HOME/anaconda/bin/activate torch export MASTER_ADDR=$HOSTNAME -export MASTER_PORT=8888 +export MASTER_PORT=60606 data_root_dir="/mnt/ceph/users/yinli/Quijote" @@ -39,9 +40,10 @@ srun m2m.py train \ --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ - --norms cosmology.dis --augment --crop 100 --pad 42 \ - --model VNet \ - --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \ + --norms cosmology.dis --augment --crop 128 --pad 20 \ + --model VNet --adv-model UNet --cgan \ + --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --epochs 128 --seed $RANDOM \ --cache --div-data # --load-state checkpoint.pth \ diff --git a/scripts/vel2vel-test.slurm b/scripts/vel2vel-test.slurm index ee19f32..e9feee9 100644 --- a/scripts/vel2vel-test.slurm +++ b/scripts/vel2vel-test.slurm @@ -13,11 +13,11 @@ hostname; pwd; date -module load gcc python3 +#module load gcc python3 +source $HOME/anaconda/bin/activate torch export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE -echo OMP_NUM_THREADS = $OMP_NUM_THREADS data_root_dir="/mnt/ceph/users/yinli/Quijote" @@ -35,7 +35,7 @@ tgt_files="$files" m2m.py test \ --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ - --norms cosmology.vel --crop 256 --pad 42 \ + --norms cosmology.vel --crop 256 --pad 20 \ --model VNet \ --load-state best_model.pth \ --batches 1 --loader-workers 0 \ diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm index 7591c35..da088c2 100644 --- a/scripts/vel2vel.slurm +++ b/scripts/vel2vel.slurm @@ -14,11 +14,12 @@ hostname; pwd; date -module load gcc python3 +#module load gcc python3 +source $HOME/anaconda/bin/activate torch export MASTER_ADDR=$HOSTNAME -export MASTER_PORT=8888 +export MASTER_PORT=60606 data_root_dir="/mnt/ceph/users/yinli/Quijote" @@ -39,9 +40,10 @@ srun m2m.py train \ --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ - --norms cosmology.vel --augment --crop 100 --pad 42 \ - --model VNet \ - --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 --seed $RANDOM \ + --norms cosmology.vel --augment --crop 128 --pad 20 \ + --model VNet --adv-model UNet --cgan \ + --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --epochs 128 --seed $RANDOM \ --cache --div-data # --load-state checkpoint.pth \