Add autocropper

This commit is contained in:
Guilhem Lavaux 2024-04-05 09:11:23 +02:00
parent ad2e01cb03
commit 2d263287e6
3 changed files with 37 additions and 1 deletions

35
map2map/cropper.py Normal file
View File

@ -0,0 +1,35 @@
import click
import numpy as np
import h5py as h5
import pathlib
from tqdm import tqdm
def _extract_3d_tile_periodic(arr, tile_size, start_index):
periodic_indices = map(
lambda a: a[0] + a[1],
zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index),
)
periodic_indices = map(
lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape)
)
return arr[tuple(periodic_indices)]
@click.command()
@click.option("--input", required=True, type=click.Path(exists=True), help="Input file")
@click.option("--output", required=True, type=click.Path(), help="Output directory")
@click.option(
"--tiles", required=True, type=click.Tuple([int]), help="Size of the tiles"
)
@click.option("--fields", required=True, type=click.Tuple([str]), help="Fields to crop")
@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop")
def cropper(input, output, tiles, fields, num_tiles):
output = pathlib.PosixPath(output)
with h5.File(input, mode="r") as f:
for i in tqdm(range(num_tiles)):
a, b, c = np.random.randint(0, high=1024, size=3)
for field in fields:
tile = _extract_3d_tile_periodic(f[field], Q, (a, b, c))
np.save(output / "tiles" / field / "{:04d}.npy".format(i), tile)

View File

@ -1,4 +1,3 @@
from .args import get_args
from . import train
from . import test
import click

View File

@ -38,6 +38,7 @@ dependencies = [
'scipy',
'matplotlib',
'tensorboard',
'h5py','tqdm',
'click','pyyaml']
authors = [
@ -51,6 +52,7 @@ maintainers = [
[project.scripts]
m2m = "map2map:main.main"
mapcropper = "map2map:cropper.cropper"
[tool.poetry.scripts]
map2map = "map2map:main"