Fix to rename --batches to --batch-size, former kept for BC
This commit is contained in:
parent
7be3153206
commit
4799f8177c
@ -72,7 +72,9 @@ def add_common_args(parser):
|
|||||||
help='allow incompatible keys when loading model states',
|
help='allow incompatible keys when loading model states',
|
||||||
dest='load_state_strict')
|
dest='load_state_strict')
|
||||||
|
|
||||||
parser.add_argument('--batches', type=int, required=True,
|
# somehow I named it "batches" instead of batch_size at first
|
||||||
|
# "batches" is kept for now for backward compatibility
|
||||||
|
parser.add_argument('--batch-size', '--batches', type=int, required=True,
|
||||||
help='mini-batch size, per GPU in training or in total in testing')
|
help='mini-batch size, per GPU in training or in total in testing')
|
||||||
parser.add_argument('--loader-workers', default=-8, type=int,
|
parser.add_argument('--loader-workers', default=-8, type=int,
|
||||||
help='number of subprocesses per data loader. '
|
help='number of subprocesses per data loader. '
|
||||||
@ -174,7 +176,7 @@ def str_list(s):
|
|||||||
|
|
||||||
def set_common_args(args):
|
def set_common_args(args):
|
||||||
if args.loader_workers < 0:
|
if args.loader_workers < 0:
|
||||||
args.loader_workers *= - args.batches
|
args.loader_workers *= - args.batch_size
|
||||||
|
|
||||||
|
|
||||||
def set_train_args(args):
|
def set_train_args(args):
|
||||||
|
@ -34,7 +34,7 @@ def test(args):
|
|||||||
)
|
)
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
batch_size=args.batches,
|
batch_size=args.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
)
|
)
|
||||||
|
@ -81,7 +81,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
div_shuffle_dist=args.div_shuffle_dist)
|
div_shuffle_dist=args.div_shuffle_dist)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batches,
|
batch_size=args.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
@ -112,7 +112,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
div_shuffle_dist=args.div_shuffle_dist)
|
div_shuffle_dist=args.div_shuffle_dist)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batches,
|
batch_size=args.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
sampler=val_sampler,
|
sampler=val_sampler,
|
||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
|
@ -26,7 +26,7 @@ m2m.py test \
|
|||||||
--test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \
|
--test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \
|
||||||
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
||||||
--model model.Net --callback-at . \
|
--model model.Net --callback-at . \
|
||||||
--batches 1 \
|
--batch-size 1 \
|
||||||
--load-state checkpoint.pt
|
--load-state checkpoint.pt
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ srun m2m.py train \
|
|||||||
--val-tgt-patterns "val/D0-*.npy,val/D1-*.npy" \
|
--val-tgt-patterns "val/D0-*.npy,val/D1-*.npy" \
|
||||||
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
--in-norms RnD.R0,RnD.R1 --tgt-norms RnD.D0,RnD.D1 \
|
||||||
--model model.Net --callback-at . \
|
--model model.Net --callback-at . \
|
||||||
--lr 1e-4 --batches 1 \
|
--lr 1e-4 --batch-size 1 \
|
||||||
--epochs 1024 --seed $RANDOM
|
--epochs 1024 --seed $RANDOM
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user