map2map/map2map/cropper.py
2024-04-05 09:11:23 +02:00

36 lines
1.3 KiB
Python

import click
import numpy as np
import h5py as h5
import pathlib
from tqdm import tqdm
def _extract_3d_tile_periodic(arr, tile_size, start_index):
periodic_indices = map(
lambda a: a[0] + a[1],
zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index),
)
periodic_indices = map(
lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape)
)
return arr[tuple(periodic_indices)]
@click.command()
@click.option("--input", required=True, type=click.Path(exists=True), help="Input file")
@click.option("--output", required=True, type=click.Path(), help="Output directory")
@click.option("--tiles", required=True, type=int, help="Size of the tiles")
@click.option(
    "--fields",
    required=True,
    type=str,
    multiple=True,
    help="Fields to crop",
)
@click.option(
    "--num_tiles",
    required=True,
    type=click.IntRange(min=0),
    help="Number of tiles to crop",
)
@click.option(
    "--seed",
    type=int,
    default=None,
    help="Seed for the random tile positions (default: unseeded)",
)
def cropper(input, output, tiles, fields, num_tiles, seed=None):
    """Cut random periodic cubic tiles out of datasets of an HDF5 file.

    ``num_tiles`` tile origins are drawn uniformly over the first three axes
    of the first field. For every field given with ``--fields`` (repeat the
    option for several fields) the tile of edge ``tiles`` starting at origin
    ``i`` is saved as ``<output>/tiles/<field>/<i:04d>.npy``. Tiles wrap
    around the box edges, and the same origins are used for every field, so
    tile ``i`` covers the same region in all fields.

    Each field is read into memory once and all tiles are cut from that copy,
    so a single field must fit in RAM.
    """
    output = pathlib.Path(output)
    rng = np.random.RandomState(seed)
    with h5.File(input, mode="r") as f:
        # Draw origins inside the actual box rather than a hard-coded 1024;
        # origins are wrapped by the extractor, so fields of other sizes
        # still get valid tiles.
        box = f[fields[0]].shape[:3]
        starts = [
            tuple(rng.randint(0, length) for length in box)
            for _ in range(num_tiles)
        ]

        with tqdm(total=len(fields) * num_tiles) as progress:
            for field in fields:
                tile_dir = output / "tiles" / field
                # np.save does not create missing directories.
                tile_dir.mkdir(parents=True, exist_ok=True)
                data = f[field][()]
                for i, start in enumerate(starts):
                    tile = _extract_3d_tile_periodic(data, tiles, start)
                    np.save(tile_dir / "{:04d}.npy".format(i), tile)
                    progress.update(1)