diff --git a/map2map/args.py b/map2map/args.py index 5b1e8e3..ad440c7 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -72,7 +72,9 @@ def add_common_args(parser): help='allow incompatible keys when loading model states', dest='load_state_strict') - parser.add_argument('--batches', type=int, required=True, + # somehow I named it "batches" instead of batch_size at first + # "batches" is kept for now for backward compatibility + parser.add_argument('--batch-size', '--batches', type=int, required=True, help='mini-batch size, per GPU in training or in total in testing') parser.add_argument('--loader-workers', default=-8, type=int, help='number of subprocesses per data loader. ' @@ -174,7 +176,7 @@ def str_list(s): def set_common_args(args): if args.loader_workers < 0: - args.loader_workers *= - args.batches + args.loader_workers *= - args.batch_size def set_train_args(args): diff --git a/map2map/test.py b/map2map/test.py index 46ce339..dc87486 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -34,7 +34,7 @@ def test(args): ) test_loader = DataLoader( test_dataset, - batch_size=args.batches, + batch_size=args.batch_size, shuffle=False, num_workers=args.loader_workers, ) diff --git a/map2map/train.py b/map2map/train.py index bc23e3f..7bf9c73 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -81,7 +81,7 @@ def gpu_worker(local_rank, node, args): div_shuffle_dist=args.div_shuffle_dist) train_loader = DataLoader( train_dataset, - batch_size=args.batches, + batch_size=args.batch_size, shuffle=False, sampler=train_sampler, num_workers=args.loader_workers, @@ -112,7 +112,7 @@ def gpu_worker(local_rank, node, args): div_shuffle_dist=args.div_shuffle_dist) val_loader = DataLoader( val_dataset, - batch_size=args.batches, + batch_size=args.batch_size, shuffle=False, sampler=val_sampler, num_workers=args.loader_workers, diff --git a/scripts/example-test.slurm b/scripts/example-test.slurm index a118f0e..1036450 100644 --- a/scripts/example-test.slurm +++ b/scripts/example-test.slurm @@ -26,7 +26,7 @@ m2m.py test \ --test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \ --in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \ --model model.Net --callback-at . \ - --batches 1 \ + --batch-size 1 \ --load-state checkpoint.pt diff --git a/scripts/example-train.slurm b/scripts/example-train.slurm index 580616a..3cd4049 100644 --- a/scripts/example-train.slurm +++ b/scripts/example-train.slurm @@ -36,7 +36,7 @@ srun m2m.py train \ --val-tgt-patterns "val/D0-*.npy,val/D1-*.npy" \ --in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \ --model model.Net --callback-at . \ - --lr 1e-4 --batches 1 \ + --lr 1e-4 --batch-size 1 \ --epochs 1024 --seed $RANDOM