Fix to rename --batches to --batch-size, former kept for BC
This commit is contained in:
parent
7be3153206
commit
4799f8177c
@ -72,7 +72,9 @@ def add_common_args(parser):
|
||||
help='allow incompatible keys when loading model states',
|
||||
dest='load_state_strict')
|
||||
|
||||
parser.add_argument('--batches', type=int, required=True,
|
||||
# somehow I named it "batches" instead of batch_size at first
|
||||
# "batches" is kept for now for backward compatibility
|
||||
parser.add_argument('--batch-size', '--batches', type=int, required=True,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', default=-8, type=int,
|
||||
help='number of subprocesses per data loader. '
|
||||
@ -174,7 +176,7 @@ def str_list(s):
|
||||
|
||||
def set_common_args(args):
|
||||
if args.loader_workers < 0:
|
||||
args.loader_workers *= - args.batches
|
||||
args.loader_workers *= - args.batch_size
|
||||
|
||||
|
||||
def set_train_args(args):
|
||||
|
@ -34,7 +34,7 @@ def test(args):
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batches,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.loader_workers,
|
||||
)
|
||||
|
@ -81,7 +81,7 @@ def gpu_worker(local_rank, node, args):
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batches,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
sampler=train_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
@ -112,7 +112,7 @@ def gpu_worker(local_rank, node, args):
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
sampler=val_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
|
@ -26,7 +26,7 @@ m2m.py test \
|
||||
--test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \
|
||||
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
||||
--model model.Net --callback-at . \
|
||||
--batches 1 \
|
||||
--batch-size 1 \
|
||||
--load-state checkpoint.pt
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ srun m2m.py train \
|
||||
--val-tgt-patterns "val/D0-*.npy,val/D1-*.npy" \
|
||||
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
||||
--model model.Net --callback-at . \
|
||||
--lr 1e-4 --batches 1 \
|
||||
--lr 1e-4 --batch-size 1 \
|
||||
--epochs 1024 --seed $RANDOM
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user