diff --git a/map2map/state.py b/map2map/state.py index d38b476..842db19 100644 --- a/map2map/state.py +++ b/map2map/state.py @@ -1,9 +1,9 @@ import warnings -from pprint import pprint +from pprint import pformat def load_model_state_dict(model, state_dict, strict=True): bad_keys = model.load_state_dict(state_dict, strict) if bad_keys.missing_keys or bad_keys.unexpected_keys: - warnings.warn(pprint(repr(bad_keys))) + warnings.warn(pformat(bad_keys)) diff --git a/map2map/train.py b/map2map/train.py index 1a84c45..981b421 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -47,7 +47,6 @@ def node_worker(args): node = int(os.environ['SLURM_NODEID']) if node == 0: pprint(vars(args)) - args.node = node spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)