from os.path import isfile
from pysbmy.timestepping import StandardTimeStepping
import numpy as np
from argparse import ArgumentParser

def register_arguments_timestepping(parser: ArgumentParser):
    """
    Add the timestepping command-line options to ``parser``.

    Options (registered in this order):
      -nt / --nsteps          int, default 10
      --integrator            str, default "COLAm"
      --TimeStepDistribution  str, default "a"
      --Snapshots             zero or more ints, default None
      --n_LPT                 float, default -2.5
    """
    add = parser.add_argument

    add("-nt", "--nsteps",
        type=int, default=10,
        help="Number of timesteps.")
    add("--integrator",
        type=str, default="COLAm",
        help="Integrator to use.")
    add("--TimeStepDistribution",
        type=str, default="a",
        help="Time step distribution.")
    # nargs="*" lets the user pass any number of step indices (including none).
    add("--Snapshots",
        type=int, nargs="*", default=None,
        help="Snapshots of steps to save.")
    add("--n_LPT",
        type=float, default=-2.5,
        help="Modified discretisation parameters for COLAm.")


def parse_arguments_timestepping(parsed_args):
    """
    Turn parsed command-line arguments into keyword arguments for the timestepping.

    The parameter card and cosmology are parsed through their own modules, then
    combined with the timestepping options registered by
    ``register_arguments_timestepping``.

    Returns
    -------
    timestepping_dict : dict
        Keys, in insertion order: ``ai``, ``af`` (the card's ``RedshiftLPT`` /
        ``RedshiftFCs`` passed through ``z2a``), ``nsteps``, ``n_LPT``, ``cosmo``,
        ``lightcone``, ``integrator`` (int code 0-4), ``TimeStepDistribution``
        (int code 0-3) and ``snapshots`` (numpy array of length ``nsteps`` with 1
        at every requested snapshot index and 0 elsewhere).
    ts_filename : str
        The card's ``TimeSteppingFileName`` entry.

    Raises
    ------
    ValueError
        For an unknown integrator or time-step distribution name, or for a
        snapshot index outside ``[0, nsteps)``.
    """
    from parameters_card import parse_arguments_card_for_timestepping
    from cosmo_params import parse_arguments_cosmo, z2a

    # Every accepted spelling maps to the integer code stored in the dict.
    integrator_codes = {
        "PM": 0, "StandardLeapfrog": 0,
        "COLA": 1,
        "COLAm": 2, "COLA_mod": 2,
        "BF": 3, "BullFrog": 3,
        "LPT": 4,
    }
    distribution_codes = {
        "a": 0, "lin_a": 0, "linear": 0,
        "log": 1, "log_a": 1, "logarithmic": 1,
        "exp": 2, "exp_a": 2, "exponential": 2,
        "D": 3, "lin_D": 3, "growth": 3,
    }

    card = parse_arguments_card_for_timestepping(parsed_args)
    cosmo = parse_arguments_cosmo(parsed_args)

    timestepping_dict = {
        "ai": z2a(card["RedshiftLPT"]),
        "af": z2a(card["RedshiftFCs"]),
        "nsteps": parsed_args.nsteps,
        "n_LPT": parsed_args.n_LPT,
        "cosmo": cosmo,
        "lightcone": card["GenerateLightcone"],
    }
    ts_filename = card["TimeSteppingFileName"]

    integrator_code = integrator_codes.get(parsed_args.integrator)
    if integrator_code is None:
        raise ValueError(f"Integrator {parsed_args.integrator} not recognised.")
    timestepping_dict["integrator"] = integrator_code

    distribution_code = distribution_codes.get(parsed_args.TimeStepDistribution)
    if distribution_code is None:
        raise ValueError(f"Time step distribution {parsed_args.TimeStepDistribution} not recognised.")
    timestepping_dict["TimeStepDistribution"] = distribution_code

    nsteps = parsed_args.nsteps
    requested = parsed_args.Snapshots if parsed_args.Snapshots is not None else ()
    snapshots = np.zeros(nsteps)
    for snap in requested:
        # Reject explicitly instead of letting negative indices wrap around.
        if snap < 0 or snap >= nsteps:
            raise ValueError(f"Snapshot {snap} is out of range.")
        snapshots[snap] = 1
    timestepping_dict["snapshots"] = snapshots

    return timestepping_dict, ts_filename


def create_timestepping(timestepping_dict, ts_filename: str, verbose: int = 1):
    """
    Build a ``StandardTimeStepping`` from ``timestepping_dict`` and write it to ``ts_filename``.

    When ``verbose`` is 2 or more, the write runs with its normal output. Below 2,
    stdout and stderr are handed to ``low_level``'s redirectors during the write,
    pointed at in-memory buffers. Those buffers are closed without being read,
    so the captured output is discarded.
    """
    ts = StandardTimeStepping(**timestepping_dict)

    if verbose >= 2:
        ts.write(ts_filename)
        return

    from io import BytesIO
    from low_level import stdout_redirector, stderr_redirector

    out_buffer = BytesIO()
    err_buffer = BytesIO()
    with stdout_redirector(out_buffer):
        with stderr_redirector(err_buffer):
            ts.write(ts_filename)
        # Close each buffer only after its redirector has exited.
        err_buffer.close()
    out_buffer.close()


def main_timestepping(parsed_args):
    """
    Entry point of the timestepping module.

    Parses the timestepping configuration from ``parsed_args`` and writes the
    timestepping file. An existing file is kept as-is unless
    ``parsed_args.force`` is truthy.

    Parameters
    ----------
    parsed_args : argparse.Namespace
        Must provide ``verbose`` and ``force``, plus every attribute read by
        ``parse_arguments_timestepping``.

    Returns
    -------
    dict
        The timestepping dictionary. It is returned whether or not the file was
        (re)written.
    """
    from low_level import print_message, print_ending_module, print_starting_module

    verbose = parsed_args.verbose

    print_starting_module("timestepping", verbose=verbose)
    print_message("Parsing arguments for the timestepping file.", 1, "timestepping", verbose=verbose)
    timestepping_dict, ts_filename = parse_arguments_timestepping(parsed_args)

    if isfile(ts_filename) and not parsed_args.force:
        print_message(f"Timestepping file {ts_filename} already exists. Use -F to overwrite.", 1, "timestepping", verbose=verbose)
    else:
        create_timestepping(timestepping_dict, ts_filename, verbose=verbose)
        print_message(f"Timestepping file written to {ts_filename}", 2, "timestepping", verbose=verbose)

    # Reached on both paths, so every print_starting_module above gets its
    # matching ending banner, including when the existing file is reused.
    print_ending_module("timestepping", verbose=verbose)

    return timestepping_dict

if __name__ == "__main__":
    from args_main import register_arguments_main
    from parameters_card import register_arguments_card_for_timestepping
    from cosmo_params import register_arguments_cosmo

    parser = ArgumentParser(description="Create timestepping file.")
    # TODO: reduce the volume of arguments
    # Registration order determines the order options appear in --help.
    for register in (
        register_arguments_main,
        register_arguments_timestepping,
        register_arguments_card_for_timestepping,
        register_arguments_cosmo,
    ):
        register(parser)
    parsed_args = parser.parse_args()
    main_timestepping(parsed_args)