Add plt.close to avoid matplotlib warning

This commit is contained in:
Yin Li 2020-07-28 07:10:46 -07:00
parent a886e53c54
commit 0d41bdae26

View File

@ -12,6 +12,8 @@ from matplotlib.cm import ScalarMappable
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions.
"""
plt.close('all')
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
else field for field in fields]