Fix unstable training by limiting pytorch version to 1.1
parent 437126e296
commit 11c9caa1e2
@@ -48,7 +48,7 @@ def test(args):
 
         loss = criterion(output, target)
 
-        print('sample {} loss: {}'.format(i, loss))
+        print('sample {} loss: {}'.format(i, loss.item()))
 
         if args.norms is not None:
             norm = test_dataset.norms[0]  # FIXME
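The only change in this hunk is logging the Python scalar instead of the tensor. A minimal illustration of the difference (dummy value, not from the run):

    import torch

    loss = torch.tensor(0.25)
    print('sample {} loss: {}'.format(0, loss))         # prints the tensor repr, e.g. tensor(0.2500)
    print('sample {} loss: {}'.format(0, loss.item()))  # prints a plain float, 0.25

Besides the cleaner output, .item() returns a detached Python number, so the printed value carries no autograd bookkeeping.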
@@ -48,7 +48,8 @@ def gpu_worker(local_rank, args):
         norms=args.norms,
         pad_or_crop=args.pad_or_crop,
     )
-    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    #train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    train_sampler = DistributedSampler(train_dataset)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
@@ -65,7 +66,8 @@ def gpu_worker(local_rank, args):
         norms=args.norms,
         pad_or_crop=args.pad_or_crop,
     )
-    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+    #val_sampler = DistributedSampler(val_dataset, shuffle=False)
+    val_sampler = DistributedSampler(val_dataset)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
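Both sampler changes above drop the shuffle keyword, which DistributedSampler appears to have gained only after the 1.1 release; in torch 1.1 the sampler always reshuffles each epoch via set_epoch(). A self-contained sketch of the 1.1-compatible usage (num_replicas and rank are passed explicitly here only so the example runs without an initialized process group):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8, dtype=torch.float))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
    sampler.set_epoch(0)  # reseeds the per-epoch shuffle, as the training loop does
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    for batch, in loader:
        print(batch)

One side effect worth noting: with no shuffle=False available, the validation sampler now shuffles too, which reorders samples but does not change any averaged metric.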
@@ -112,9 +114,9 @@ def gpu_worker(local_rank, args):
 
     if args.rank == 0:
         args.logger = SummaryWriter()
-        hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
-                else str(v) for k, v in vars(args).items()}
-        args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
+        #hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
+        # else str(v) for k, v in vars(args).items()}
+        #args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
 
     for epoch in range(args.start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
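The hparams logging is commented out rather than rewritten, presumably because SummaryWriter.add_hparams() is not available in the pinned release (it was added in a later PyTorch version). If the run configuration still needs to reach TensorBoard under 1.1, one workaround (an assumption on my part, not part of this commit) is to record it as plain text:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
    hparam = {'lr': 1e-3, 'batches': 4}  # stand-in for vars(args)
    writer.add_text('hparams', str(hparam), global_step=0)
    writer.close()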
@@ -125,7 +127,8 @@ def gpu_worker(local_rank, args):
         scheduler.step(val_loss)
 
         if args.rank == 0:
-            args.logger.close()
+            print(end='', flush=True)
+            args.logger.flush()
 
             state = {
                 'epoch': epoch + 1,
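Replacing close() with flush() keeps the SummaryWriter open across epochs while still forcing pending events to disk each epoch; the extra print(end='', flush=True) flushes Python's stdout so the scheduler's log file stays current. The per-epoch pattern, as I read it, looks roughly like:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
    for epoch in range(2):
        writer.add_scalar('loss/val', 1.0 / (epoch + 1), epoch)  # dummy values
        writer.flush()             # push events without closing the writer
        print(end='', flush=True)  # flush stdout for the batch scheduler's log
    writer.close()                 # close once, after the loop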
@@ -15,10 +15,7 @@
 hostname; pwd; date
 
 
-module load gcc openmpi2
-module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
-
-source $HOME/anaconda3/bin/activate torch
+module load gcc python3
 
 
 export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
@@ -37,7 +34,7 @@ in_files="$files"
 tgt_files="$files"
 
 
-srun m2m.py test \
+m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.dis \
@@ -17,10 +17,7 @@
 hostname; pwd; date
 
 
-module load gcc openmpi2
-module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
-
-source $HOME/anaconda3/bin/activate torch
+module load gcc python3
 
 
 export MASTER_ADDR=$HOSTNAME
@@ -15,10 +15,7 @@
 hostname; pwd; date
 
 
-module load gcc openmpi2
-module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
-
-source $HOME/anaconda3/bin/activate torch
+module load gcc python3
 
 
 export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
@@ -37,7 +34,7 @@ in_files="$files"
 tgt_files="$files"
 
 
-srun m2m.py test \
+m2m.py test \
    --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
    --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
    --in-channels 3 --out-channels 3 --norms cosmology.vel \
@@ -17,10 +17,7 @@
 hostname; pwd; date
 
 
-module load gcc openmpi2
-module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
-
-source $HOME/anaconda3/bin/activate torch
+module load gcc python3
 
 
 export MASTER_ADDR=$HOSTNAME
setup.py (4 changed lines)
@@ -5,11 +5,11 @@ setup(
     name='map2map',
     version='0.0',
     description='Neural network emulators to transform field data',
-    author='Yin Li',
+    author='Yin Li et al.',
     author_email='eelregit@gmail.com',
     packages=find_packages(),
     install_requires=[
-        'torch',
+        'torch==1.1',
         'numpy',
         'scipy',
         'tensorboard',
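With install_requires now pinning torch==1.1, a quick runtime check (illustrative only, not part of the commit) confirms the environment actually picked up the pinned build:

    import torch

    assert torch.__version__.startswith('1.1'), torch.__version__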