36 lines
1.3 KiB
Python
36 lines
1.3 KiB
Python
import click
|
|
import numpy as np
|
|
import h5py as h5
|
|
import pathlib
|
|
from tqdm import tqdm
|
|
|
|
|
|
def _periodic_slices(start, length, size):
    """Split the periodic index range ``start, ..., start + length - 1`` (mod ``size``)
    into a list of contiguous ``slice`` objects, in order.

    Plain slices (rather than integer index arrays) are produced so the caller
    can read from objects that only support basic slicing, such as
    ``h5py.Dataset``.
    """
    if length <= 0:
        # One empty slice keeps a zero-length axis in the stitched result.
        return [slice(0, 0)]
    if size <= 0:
        raise ValueError("cannot extract a periodic tile along an empty axis")
    slices = []
    pos = int(start) % size  # Python's % is non-negative for size > 0
    remaining = int(length)
    while remaining > 0:
        stop = min(size, pos + remaining)
        slices.append(slice(pos, stop))
        remaining -= stop - pos
        pos = 0  # every later piece wraps around to the start of the axis
    return slices


def _extract_3d_tile_periodic(arr, tile_size, start_index):
    """Extract a tile from ``arr``, treating its first three axes as periodic.

    Parameters
    ----------
    arr : array-like
        Object with a ``shape`` attribute that supports indexing with a tuple
        of three ``slice`` objects, e.g. a NumPy array or an ``h5py.Dataset``.
        Axes after the third are returned whole.
    tile_size : int or sequence of int
        Tile edge length. A scalar (or 1-element sequence) gives a cubic tile;
        a 3-element sequence gives one length per axis.
    start_index : sequence of 3 ints
        Position of the tile's first element along each axis. Any integer is
        accepted; it is wrapped modulo the axis length.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(t0, t1, t2) + arr.shape[3:]`` whose element
        ``[i, j, k]`` equals ``arr[(s0 + i) % n0, (s1 + j) % n1, (s2 + k) % n2]``.
    """
    sizes = tuple(int(t) for t in np.broadcast_to(tile_size, (3,)))
    x_slices, y_slices, z_slices = (
        _periodic_slices(start, length, extent)
        for start, length, extent in zip(start_index, sizes, arr.shape[:3])
    )
    # Read each contiguous hyperslab separately and stitch them together
    # axis by axis. This avoids fancy indexing, which h5py datasets do not
    # support across several axes, and reads only the data the tile needs.
    x_parts = []
    for sx in x_slices:
        y_parts = []
        for sy in y_slices:
            z_parts = [np.asarray(arr[sx, sy, sz]) for sz in z_slices]
            y_parts.append(np.concatenate(z_parts, axis=2))
        x_parts.append(np.concatenate(y_parts, axis=1))
    return np.concatenate(x_parts, axis=0)
|
|
|
|
|
|
@click.command()
@click.option("--input", required=True, type=click.Path(exists=True), help="Input file")
@click.option("--output", required=True, type=click.Path(), help="Output directory")
@click.option("--tiles", required=True, type=int, help="Size of the tiles")
@click.option(
    "--fields",
    required=True,
    multiple=True,
    type=str,
    help="Fields to crop (repeat the option for several fields)",
)
@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop")
@click.option(
    "--seed",
    default=None,
    type=int,
    help="Seed for the random tile positions (default: unseeded)",
)
def cropper(input, output, tiles, fields, num_tiles, seed=None):
    """Crop randomly placed periodic cubic tiles out of HDF5 datasets.

    For each of ``num_tiles`` iterations, one random start position is drawn
    and the same tile location is cut from every requested field. Each tile
    is saved as ``<output>/tiles/<field>/<NNNN>.npy``.

    Parameters
    ----------
    input : str
        Path of the HDF5 file to read.
    output : str
        Directory under which the ``tiles/<field>/`` folders are created.
    tiles : int
        Edge length of each cubic tile.
    fields : tuple of str
        Names of the datasets in ``input`` to crop.
    num_tiles : int
        Number of tile positions to sample.
    seed : int or None
        Seed for the random generator; ``None`` gives non-reproducible positions.

    Raises
    ------
    click.BadParameter
        If a requested field is not present in the input file.
    """
    output = pathlib.Path(output)
    rng = np.random.default_rng(seed)

    with h5.File(input, mode="r") as f:
        missing = [field for field in fields if field not in f]
        if missing:
            raise click.BadParameter(
                "not found in {}: {}".format(input, ", ".join(missing)),
                param_hint="--fields",
            )

        # np.save does not create parent directories, so create them up front.
        field_dirs = {field: output / "tiles" / field for field in fields}
        for directory in field_dirs.values():
            directory.mkdir(parents=True, exist_ok=True)

        # Sample starts over the first field's extent. The tile extractor
        # wraps each start modulo the shape of the field it reads.
        shape = f[fields[0]].shape[:3]
        for i in tqdm(range(num_tiles)):
            start = tuple(int(rng.integers(0, extent)) for extent in shape)
            for field in fields:
                tile = _extract_3d_tile_periodic(f[field], tiles, start)
                np.save(field_dirs[field] / "{:04d}.npy".format(i), tile)


if __name__ == "__main__":
    cropper()
|