diff --git a/map2map/cropper.py b/map2map/cropper.py new file mode 100644 index 0000000..52ad864 --- /dev/null +++ b/map2map/cropper.py @@ -0,0 +1,35 @@ +import click +import numpy as np +import h5py as h5 +import pathlib +from tqdm import tqdm + + +def _extract_3d_tile_periodic(arr, tile_size, start_index): + periodic_indices = map( + lambda a: a[0] + a[1], + zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index), + ) + periodic_indices = map( + lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape) + ) + return arr[tuple(periodic_indices)] + + +@click.command() +@click.option("--input", required=True, type=click.Path(exists=True), help="Input file") +@click.option("--output", required=True, type=click.Path(), help="Output directory") +@click.option( + "--tiles", required=True, type=click.Tuple([int]), help="Size of the tiles" +) +@click.option("--fields", required=True, type=click.Tuple([str]), help="Fields to crop") +@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop") +def cropper(input, output, tiles, fields, num_tiles): + output = pathlib.PosixPath(output) + + with h5.File(input, mode="r") as f: + for i in tqdm(range(num_tiles)): + a, b, c = np.random.randint(0, high=1024, size=3) + for field in fields: + tile = _extract_3d_tile_periodic(f[field], Q, (a, b, c)) + np.save(output / "tiles" / field / "{:04d}.npy".format(i), tile) diff --git a/map2map/main.py b/map2map/main.py index 63782f0..7b65930 100644 --- a/map2map/main.py +++ b/map2map/main.py @@ -1,4 +1,3 @@ -from .args import get_args from . import train from . import test import click diff --git a/pyproject.toml b/pyproject.toml index 7b5734e..0d628dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ 'scipy', 'matplotlib', 'tensorboard', + 'h5py','tqdm', 'click','pyyaml'] authors = [ @@ -31,6 +32,7 @@ maintainers = [ [project.scripts] m2m = "map2map:main.main" +mapcropper = "map2map:cropper.cropper" [project.urls] #Homepage = "https://example.com"