From 88bfd11594f07a6c0e1d927879aea59fafe068c3 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Sat, 30 Nov 2019 15:32:45 -0500 Subject: [PATCH] Add training --- map2map/__init__.py | 0 map2map/args.py | 76 +++++++++++++ map2map/data/__init__.py | 1 + map2map/data/fields.py | 101 +++++++++++++++++ map2map/data/norms/__init__.py | 1 + map2map/data/norms/cosmology.py | 56 +++++++++ map2map/main.py | 13 +++ map2map/models/__init__.py | 2 + map2map/models/conv.py | 68 +++++++++++ map2map/models/unet.py | 53 +++++++++ map2map/test.py | 8 ++ map2map/train.py | 194 ++++++++++++++++++++++++++++++++ map2map/utils/__init__.py | 0 scripts/dis2dis.slurm | 53 +++++++++ scripts/m2m.py | 5 + scripts/vel2vel.slurm | 53 +++++++++ setup.py | 20 ++++ 17 files changed, 704 insertions(+) create mode 100644 map2map/__init__.py create mode 100644 map2map/args.py create mode 100644 map2map/data/__init__.py create mode 100644 map2map/data/fields.py create mode 100644 map2map/data/norms/__init__.py create mode 100644 map2map/data/norms/cosmology.py create mode 100644 map2map/main.py create mode 100644 map2map/models/__init__.py create mode 100644 map2map/models/conv.py create mode 100644 map2map/models/unet.py create mode 100644 map2map/test.py create mode 100644 map2map/train.py create mode 100644 map2map/utils/__init__.py create mode 100644 scripts/dis2dis.slurm create mode 100644 scripts/m2m.py create mode 100644 scripts/vel2vel.slurm create mode 100644 setup.py diff --git a/map2map/__init__.py b/map2map/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/map2map/args.py b/map2map/args.py new file mode 100644 index 0000000..3987100 --- /dev/null +++ b/map2map/args.py @@ -0,0 +1,76 @@ +from argparse import ArgumentParser + + +def get_args(): + parser = ArgumentParser(description='Transform field(s) to field(s)') + subparsers = parser.add_subparsers(title='modes', dest='mode', required=True) + train_parser = subparsers.add_parser('train') + test_parser = subparsers.add_parser('test') + + add_train_args(train_parser) + add_test_args(test_parser) + + args = parser.parse_args() + + return args + + +def add_common_args(parser): + parser.add_argument('--in-channels', type=int, required=True, + help='number of input channels') + parser.add_argument('--out-channels', type=int, required=True, + help='number of output or target channels') + parser.add_argument('--norms', type=str_list, help='comma-sep. list ' + 'of normalization functions from map2map.data.norms') + parser.add_argument('--criterion', default='MSELoss', + help='model criterion from torch.nn') + parser.add_argument('--load-state', default='', type=str, + help='path to load model, optimizer, rng, etc.') + + +def add_train_args(parser): + add_common_args(parser) + + parser.add_argument('--train-in-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for training input data') + parser.add_argument('--train-tgt-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for training target data') + parser.add_argument('--val-in-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for validation input data') + parser.add_argument('--val-tgt-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for validation target data') + parser.add_argument('--epochs', default=128, type=int, + help='total number of epochs to run') + parser.add_argument('--batches-per-gpu', default=8, type=int, + help='mini-batch size per GPU') + parser.add_argument('--loader-workers-per-gpu', default=4, type=int, + help='number of data loading workers per GPU') + parser.add_argument('--augment', action='store_true', + help='enable training data augmentation') + parser.add_argument('--optimizer', default='Adam', + help='optimizer from torch.optim') + parser.add_argument('--lr', default=0.001, type=float, + help='initial learning rate') +# parser.add_argument('--momentum', default=0.9, type=float, +# help='momentum') +# parser.add_argument('--weight-decay', default=1e-4, type=float, +# help='weight decay') + parser.add_argument('--dist-backend', default='nccl', type=str, + choices=['gloo', 'nccl'], help='distributed backend') + parser.add_argument('--seed', default=42, type=int, + help='seed for initializing training') + parser.add_argument('--log-interval', default=20, type=int, + help='interval between logging training loss') + + +def add_test_args(parser): + add_common_args(parser) + + parser.add_argument('--test-in-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for test input data') + parser.add_argument('--test-tgt-patterns', type=str_list, required=True, + help='comma-sep. list of glob patterns for test target data') + + +def str_list(s): + return s.split(',') diff --git a/map2map/data/__init__.py b/map2map/data/__init__.py new file mode 100644 index 0000000..0fe6e8a --- /dev/null +++ b/map2map/data/__init__.py @@ -0,0 +1 @@ +from .fields import FieldDataset diff --git a/map2map/data/fields.py b/map2map/data/fields.py new file mode 100644 index 0000000..e8fbc25 --- /dev/null +++ b/map2map/data/fields.py @@ -0,0 +1,101 @@ +from glob import glob +import numpy as np +import torch +from torch.utils.data import Dataset + +from . import norms + + +class FieldDataset(Dataset): + """Dataset of lists of fields. + + `in_patterns` is a list of glob patterns for the input fields. + For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`. + Likewise `tgt_patterns` is for target fields. + Input and target samples of all fields are matched by sorting the globbed files. + + Data augmentations are supported for scalar and vector fields. + + `normalize` can be a list of callables to normalize each field. + """ + def __init__(self, in_patterns, tgt_patterns, augment=False, + normalize=None, **kwargs): + in_file_lists = [sorted(glob(p)) for p in in_patterns] + self.in_files = list(zip(* in_file_lists)) + + tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] + self.tgt_files = list(zip(* tgt_file_lists)) + + assert len(self.in_files) == len(self.tgt_files), \ + 'input and target sample sizes do not match' + + self.augment = augment + + self.normalize = normalize + if self.normalize is not None: + assert len(in_patterns) == len(self.normalize), \ + 'numbers of normalization callables and input fields do not match' + +# self.__dict__.update(kwargs) + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, idx): + in_fields = [torch.from_numpy(np.load(f)).to(torch.float32) + for f in self.in_files[idx]] + tgt_fields = [torch.from_numpy(np.load(f)).to(torch.float32) + for f in self.tgt_files[idx]] + + if self.augment: + flip_axes = torch.randint(2, (3,), dtype=torch.bool) + flip_axes = torch.arange(3)[flip_axes] + + flip3d(in_fields, flip_axes) + flip3d(tgt_fields, flip_axes) + + perm_axes = torch.randperm(3) + + perm3d(in_fields, perm_axes) + perm3d(tgt_fields, perm_axes) + + if self.normalize is not None: + def get_norm(path): + path = path.split('.') + norm = norms + while path: + norm = norm.__dict__[path.pop(0)] + return norm + + for norm, ifield, tfield in zip(self.normalize, in_fields, tgt_fields): + if isinstance(norm, str): + norm = get_norm(norm) + + norm(ifield) + norm(tfield) + + in_fields = torch.cat(in_fields, dim=0) + tgt_fields = torch.cat(tgt_fields, dim=0) + + return in_fields, tgt_fields + + +def flip3d(fields, axes): + for i, x in enumerate(fields): + if x.size(0) == 3: # flip vector components + x[axes] = - x[axes] + + axes = (1 + axes).tolist() + x = torch.flip(x, axes) + + fields[i] = x + +def perm3d(fields, axes): + for i, x in enumerate(fields): + if x.size(0) == 3: # permutate vector components + x = x[axes] + + axes = [0] + (1 + axes).tolist() + x = x.permute(axes) + + fields[i] = x diff --git a/map2map/data/norms/__init__.py b/map2map/data/norms/__init__.py new file mode 100644 index 0000000..167c422 --- /dev/null +++ b/map2map/data/norms/__init__.py @@ -0,0 +1 @@ +from . import cosmology diff --git a/map2map/data/norms/cosmology.py b/map2map/data/norms/cosmology.py new file mode 100644 index 0000000..ee1cfa4 --- /dev/null +++ b/map2map/data/norms/cosmology.py @@ -0,0 +1,56 @@ +import numpy as np +from scipy.special import hyp2f1 + + +def dis(x, undo=False): + z = 0 # FIXME + dis_norm = 6 * D(z) # [Mpc/h] + + if not undo: + dis_norm = 1 / dis_norm + + x *= dis_norm + +def vel(x, undo=False): + z = 0 # FIXME + vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z) # [km/s] + + if not undo: + vel_norm = 1 / vel_norm + + x *= vel_norm + +def den(x, undo=False): + raise NotImplementedError + z = 0 # FIXME + den_norm = 0 # FIXME + + if not undo: + den_norm = 1 / den_norm + + x *= den_norm + + +def D(z, Om=0.31): + """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero + """ + OL = 1 - Om + a = 1 / (1+z) + return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \ + / hyp2f1(1, 1/3, 11/6, - OL / Om) + +def f(z, Om=0.31): + """linear growth rate for flat LambdaCDM + """ + OL = 1 - Om + a = 1 / (1+z) + aa3 = OL * a**3 / Om + return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \ + / hyp2f1(1, 1/3, 11/6, -aa3) + +def H(z, Om=0.31): + """Hubble in [h km/s/Mpc] for flat LambdaCDM + """ + OL = 1 - Om + a = 1 / (1+z) + return 100 * np.sqrt(Om / a**3 + OL) diff --git a/map2map/main.py b/map2map/main.py new file mode 100644 index 0000000..b9615a2 --- /dev/null +++ b/map2map/main.py @@ -0,0 +1,13 @@ +from .args import get_args +from . import train +from . import test + + +def main(): + + args = get_args() + + if args.mode == 'train': + train.node_worker(args) + elif args.mode == 'test': + pass diff --git a/map2map/models/__init__.py b/map2map/models/__init__.py new file mode 100644 index 0000000..6dcdd73 --- /dev/null +++ b/map2map/models/__init__.py @@ -0,0 +1,2 @@ +from .unet import UNet +from .conv import narrow_like diff --git a/map2map/models/conv.py b/map2map/models/conv.py new file mode 100644 index 0000000..398484a --- /dev/null +++ b/map2map/models/conv.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + + +class ConvBlock(nn.Module): + """Convolution blocks of the form specified by `seq`. + """ + def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBAC'): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + if mid_channels is None: + self.mid_channels = max(in_channels, out_channels) + + self.bn_channels = in_channels + self.idx_conv = 0 + self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']]) + + layers = [self._get_layer(l) for l in seq] + + self.convs = nn.Sequential(*layers) + + def _get_layer(self, l): + if l == 'U': + in_channels, out_channels = self._setup_conv() + return nn.ConvTranspose3d(in_channels, out_channels, 2, stride=2) + elif l == 'D': + in_channels, out_channels = self._setup_conv() + return nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif l == 'C': + in_channels, out_channels = self._setup_conv() + return nn.Conv3d(in_channels, out_channels, 3) + elif l == 'B': + return nn.BatchNorm3d(self.bn_channels) + elif l == 'A': + return nn.LeakyReLU(inplace=True) + else: + raise NotImplementedError('layer type {} not supported'.format(l)) + + def _setup_conv(self): + self.idx_conv += 1 + + in_channels = out_channels = self.mid_channels + if self.idx_conv == 1: + in_channels = self.in_channels + if self.idx_conv == self.num_conv: + out_channels = self.out_channels + + self.bn_channels = out_channels + + return in_channels, out_channels + + def forward(self, x): + return self.convs(x) + + +def narrow_like(a, b): + """Narrow a to be like b. + + Try to be symmetric but cut more on the right for odd difference, + consistent with the downsampling. + """ + for dim in range(2, 5): + width = a.size(dim) - b.size(dim) + half_width = width // 2 + a = a.narrow(dim, half_width, a.size(dim) - width) + return a diff --git a/map2map/models/unet.py b/map2map/models/unet.py new file mode 100644 index 0000000..9831419 --- /dev/null +++ b/map2map/models/unet.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + +from .conv import ConvBlock, narrow_like + + +class DownBlock(ConvBlock): + def __init__(self, in_channels, out_channels, seq='BADBA'): + super().__init__(in_channels, out_channels, seq=seq) + +class UpBlock(ConvBlock): + def __init__(self, in_channels, out_channels, seq='BAUBA'): + super().__init__(in_channels, out_channels, seq=seq) + + +class UNet(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.conv_0l = ConvBlock(in_channels, 64, seq='CAC') + self.down_0l = DownBlock(64, 64) + self.conv_1l = ConvBlock(64, 64) + self.down_1l = DownBlock(64, 64) + + self.conv_2c = ConvBlock(64, 64) + + self.up_1r = UpBlock(64, 64) + self.conv_1r = ConvBlock(128, 64) + self.up_0r = UpBlock(64, 64) + self.conv_0r = ConvBlock(128, out_channels, seq='CAC') + + def forward(self, x): + y0 = self.conv_0l(x) + x = self.down_0l(y0) + + y1 = self.conv_1l(x) + x = self.down_1l(y1) + + x = self.conv_2c(x) + + x = self.up_1r(x) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_1r(x) + + x = self.up_0r(x) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_0r(x) + + return x diff --git a/map2map/test.py b/map2map/test.py new file mode 100644 index 0000000..2eab6bd --- /dev/null +++ b/map2map/test.py @@ -0,0 +1,8 @@ +import os + +import torch +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from .data import FieldDataset +from .models import UNet, narrow_like diff --git a/map2map/train.py b/map2map/train.py new file mode 100644 index 0000000..e5d0bd3 --- /dev/null +++ b/map2map/train.py @@ -0,0 +1,194 @@ +import os +import shutil + +import torch +from torch.multiprocessing import spawn +from torch.distributed import init_process_group, destroy_process_group, all_reduce +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from .data import FieldDataset +from .models import UNet, narrow_like + + +def node_worker(args): + torch.manual_seed(args.seed) # NOTE: why here not in gpu_worker? + #torch.backends.cudnn.deterministic = True # NOTE: test perf + + args.gpus_per_node = torch.cuda.device_count() + args.nodes = int(os.environ['SLURM_JOB_NUM_NODES']) + args.world_size = args.gpus_per_node * args.nodes + + node = int(os.environ['SLURM_NODEID']) + if node == 0: + print(args) + args.node = node + + spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node) + + +def gpu_worker(local_rank, args): + args.device = torch.device('cuda', local_rank) + torch.cuda.device(args.device) + + args.rank = args.gpus_per_node * args.node + local_rank + + init_process_group( + backend=args.dist_backend, + init_method='env://', + world_size=args.world_size, + rank=args.rank + ) + + train_dataset = FieldDataset( + in_patterns=args.train_in_patterns, + tgt_patterns=args.train_tgt_patterns, + augment=args.augment, + normalize=args.norms, + ) + train_sampler = DistributedSampler(train_dataset, shuffle=True) + train_loader = DataLoader( + train_dataset, + batch_size=args.batches_per_gpu, + shuffle=False, + sampler=train_sampler, + num_workers=args.loader_workers_per_gpu, + pin_memory=True + ) + + val_dataset = FieldDataset( + in_patterns=args.val_in_patterns, + tgt_patterns=args.val_tgt_patterns, + augment=False, + normalize=args.norms, + ) + val_sampler = DistributedSampler(val_dataset, shuffle=False) + val_loader = DataLoader( + val_dataset, + batch_size=args.batches_per_gpu, + shuffle=False, + sampler=val_sampler, + num_workers=args.loader_workers_per_gpu, + pin_memory=True + ) + + model = UNet(args.in_channels, args.out_channels) + model.to(args.device) + model = DistributedDataParallel(model, device_ids=[args.device]) + + criterion = torch.nn.__dict__[args.criterion]() + criterion.to(args.device) + + optimizer = torch.optim.__dict__[args.optimizer]( + model.parameters(), + lr=args.lr, + #momentum=args.momentum, + #weight_decay=args.weight_decay + ) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + + if args.load_state: + checkpoint = torch.load(args.load_state, map_location=args.device) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + torch.set_rng_state(checkpoint['rng'].cpu()) # move rng state back + if args.rank == 0: + min_loss = checkpoint['min_loss'] + print('checkpoint of epoch {} loaded from {}'.format( + checkpoint['epoch'], args.load_state)) + del checkpoint + else: + args.start_epoch = 0 + if args.rank == 0: + min_loss = None + + torch.backends.cudnn.benchmark = True # NOTE: test perf + + if args.rank == 0: + args.logger = SummaryWriter() + hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor)) + else str(v) for k, v in vars(args).items()} + args.logger.add_hparams(hparam_dict=hparam, metric_dict={}) + + for epoch in range(args.start_epoch, args.epochs): + train_sampler.set_epoch(epoch) + train(epoch, train_loader, model, criterion, optimizer, args) + + val_loss = validate(epoch, val_loader, model, criterion, args) + + scheduler.step(val_loss) + + if args.rank == 0: + args.logger.close() + + checkpoint = { + 'epoch': epoch + 1, + 'model': model.state_dict(), + 'optimizer' : optimizer.state_dict(), + 'scheduler' : scheduler.state_dict(), + 'rng' : torch.get_rng_state(), + 'min_loss': min_loss, + } + filename='checkpoint.pth' + torch.save(checkpoint, filename) + del checkpoint + + if min_loss is None or val_loss < min_loss: + min_loss = val_loss + shutil.copyfile(filename, 'best_model.pth') + + destroy_process_group() + + +def train(epoch, loader, model, criterion, optimizer, args): + model.train() + + for i, (input, target) in enumerate(loader): + input = input.to(args.device, non_blocking=True) + target = target.to(args.device, non_blocking=True) + + output = model(input) + target = narrow_like(target, output) + + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + batch = epoch * len(loader) + i + if batch % args.log_interval == 0: + all_reduce(loss) + loss /= args.world_size + if args.rank == 0: + args.logger.add_scalar('loss/train', loss.item(), global_step=batch) + +# f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached') + +def validate(epoch, loader, model, criterion, args): + model.eval() + + loss = 0 + + with torch.no_grad(): + for i, (input, target) in enumerate(loader): + input = input.to(args.device, non_blocking=True) + target = target.to(args.device, non_blocking=True) + + output = model(input) + target = narrow_like(target, output) + + loss += criterion(output, target) + + all_reduce(loss) + loss /= len(loader) * args.world_size + if args.rank == 0: + args.logger.add_scalar('loss/val', loss.item(), global_step=epoch) + +# f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached') + + return loss.item() diff --git a/map2map/utils/__init__.py b/map2map/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm new file mode 100644 index 0000000..36cf91e --- /dev/null +++ b/scripts/dis2dis.slurm @@ -0,0 +1,53 @@ +#!/bin/bash + +#SBATCH --job-name=dis2dis +#SBATCH --dependency=singleton +#SBATCH --output=%x-%j.out +#SBATCH --error=%x-%j.err + +#SBATCH --partition=gpu +#SBATCH --gres=gpu:v100-32gb:4 + +#SBATCH --exclusive +#SBATCH --nodes=2 +#SBATCH --mem=0 +#SBATCH --time=2-00:00:00 + + +hostname; pwd; date + + +module load gcc openmpi2 +module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1 + +source $HOME/anaconda3/bin/activate torch + + +export MASTER_ADDR=$HOSTNAME +export MASTER_PORT=8888 + + +data_root_dir="/mnt/ceph/users/yinli/Quijote" + +in_dir="linear" +tgt_dir="nonlin" + +train_dirs="*[1-9]" +val_dirs="*[1-9]0" + +files="dis/128x???.npy" +in_files="$files" +tgt_files="$files" + + +srun m2m.py train \ + --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files" \ + --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ + --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ + --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ + --in-channels 3 --out-channels 3 --norms cosmology.dis --augment \ + --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4 +# --load-state checkpoint.pth + + +date diff --git a/scripts/m2m.py b/scripts/m2m.py new file mode 100644 index 0000000..ea49edf --- /dev/null +++ b/scripts/m2m.py @@ -0,0 +1,5 @@ +from map2map.main import main + + +if __name__ == '__main__': + main() diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm new file mode 100644 index 0000000..f765282 --- /dev/null +++ b/scripts/vel2vel.slurm @@ -0,0 +1,53 @@ +#!/bin/bash + +#SBATCH --job-name=vel2vel +#SBATCH --dependency=singleton +#SBATCH --output=%x-%j.out +#SBATCH --error=%x-%j.err + +#SBATCH --partition=gpu +#SBATCH --gres=gpu:v100-32gb:4 + +#SBATCH --exclusive +#SBATCH --nodes=2 +#SBATCH --mem=0 +#SBATCH --time=2-00:00:00 + + +hostname; pwd; date + + +module load gcc openmpi2 +module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1 + +source $HOME/anaconda3/bin/activate torch + + +export MASTER_ADDR=$HOSTNAME +export MASTER_PORT=8888 + + +data_root_dir="/mnt/ceph/users/yinli/Quijote" + +in_dir="linear" +tgt_dir="nonlin" + +train_dirs="*[1-9]" +val_dirs="*[1-9]0" + +files="vel/128x???.npy" +in_files="$files" +tgt_files="$files" + + +srun m2m.py train \ + --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files" \ + --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ + --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ + --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ + --in-channels 3 --out-channels 3 --norms cosmology.vel --augment \ + --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4 +# --load-state checkpoint.pth + + +date diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7e3647f --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +from setuptools import setup +from setuptools import find_packages + +setup( + name='map2map', + version='0.0', + description='Neural network emulators to transform field data', + author='Yin Li', + author_email='eelregit@gmail.com', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'scipy', + 'tensorboard', + ], + scripts=[ + 'scripts/m2m.py', + ] +)