from argparse import ArgumentParser
from args_main import parse_arguments_main

# Absolute path to the monofonIC executable; used by `create_slurm_script` to
# build the command line of "monofonic" jobs.
# NOTE(review): machine-specific hard-coded path — adjust for each installation
# (consider making it configurable).
path_to_monofonic_binary = "/home/aubin/monofonic/build/monofonIC"

def register_arguments_slurm(parser:ArgumentParser):
    """
    Add the SLURM-related command-line options to ``parser``.

    Every option is a string defaulting to ``None`` (the real defaults are
    resolved later by ``parse_arguments_slurm``):
       - ``-smf/--slurm_monofonic``:   template for monofonic submissions
       - ``-ssbmy/--slurm_simbelmyne``: template for simbelmyne submissions
       - ``-sscola/--slurm_scola``:    template for scola submissions
       - ``--slurm_logs``:             directory for the SLURM log files
       - ``--slurm_scripts``:          directory for the generated SLURM scripts
    """
    option_specs = (
        (("-smf", "--slurm_monofonic"), "Path to the monofonic SLURM submission script template."),
        (("-ssbmy", "--slurm_simbelmyne"), "Path to the simbelmyne SLURM submission script template."),
        (("-sscola", "--slurm_scola"), "Path to the scola SLURM submission script template."),
        (("--slurm_logs",), "Path to the directory where the SLURM logs will be saved."),
        (("--slurm_scripts",), "Path to the directory where the SLURM scripts will be saved."),
    )
    for flags, help_text in option_specs:
        parser.add_argument(*flags, type=str, default=None, help=help_text)


def parse_arguments_slurm(parsed_args):
    """
    Build the SLURM configuration dictionary from parsed command-line arguments.

    Options left as ``None`` on the command line fall back to defaults derived
    from ``parse_arguments_main(parsed_args)``:
      - ``*_template``: ``<paramdir>/slurm_<job>.template``
      - ``logs``:       ``<directory>/slurm_logs/``
      - ``scripts``:    ``<directory>/slurm_scripts/``
    The ``logs`` and ``scripts`` directories are created if they do not exist.

    Parameters
    ----------
    parsed_args : argparse.Namespace
        Must carry the options registered by ``register_arguments_slurm`` plus
        whatever ``parse_arguments_main`` reads.

    Returns
    -------
    dict
        Keys ``monofonic_template``, ``simbelmyne_template``, ``scola_template``,
        ``logs`` and ``scripts``.
    """
    import os
    from pathlib import Path
    main_dict = parse_arguments_main(parsed_args)

    slurm_dict = dict(
        monofonic_template=parsed_args.slurm_monofonic,
        simbelmyne_template=parsed_args.slurm_simbelmyne,
        scola_template=parsed_args.slurm_scola,
        logs=parsed_args.slurm_logs,
        scripts=parsed_args.slurm_scripts
    )

    # key -> (main_dict entry holding the base directory, file/dir name below it)
    fallbacks = {
        "monofonic_template": ("paramdir", "slurm_monofonic.template"),
        "simbelmyne_template": ("paramdir", "slurm_simbelmyne.template"),
        "scola_template": ("paramdir", "slurm_scola.template"),
        "logs": ("directory", "slurm_logs/"),
        "scripts": ("directory", "slurm_scripts/"),
    }
    for key, (base_key, leaf) in fallbacks.items():
        if slurm_dict[key] is None:
            # os.path.join tolerates a base directory without a trailing
            # separator (plain "+" would yield e.g. "paramsslurm_scola.template").
            # The trailing "/" of the directory defaults is kept, presumably for
            # callers that use these values as string prefixes.
            slurm_dict[key] = os.path.join(main_dict[base_key], leaf)

    Path(slurm_dict["logs"]).mkdir(parents=True, exist_ok=True)
    Path(slurm_dict["scripts"]).mkdir(parents=True, exist_ok=True)

    return slurm_dict


def create_slurm_template(
            slurm_template:str,
            job_name:str,
            ntasks:int,
            nthreads:int,
            partition:str,
            time:str,
            mem:int,
            log_out:str,
            log_err:str,
            array:tuple|None=None,
                          ):
    """
    Creates a SLURM submission script template.
    """

    with open(slurm_template, "w") as f:
        f.write("#!/bin/bash\n")
        f.write(f"#SBATCH --job-name={job_name}\n")
        f.write(f"#SBATCH --ntasks={ntasks}\n")
        f.write(f"#SBATCH --cpus-per-task={nthreads}\n")
        f.write(f"#SBATCH --partition={partition}\n")
        f.write(f"#SBATCH --time={time}\n")
        f.write(f"#SBATCH --mem={mem}G\n")

        if array is not None:
            f.write(f"#SBATCH --array={array[0]}-{array[1]}\n")
            f.write(f"#SBATCH --output={log_out}%x_%a_%A.out\n")
            f.write(f"#SBATCH --error={log_err}%x_%a_%A.err\n")
        else:
            f.write(f"#SBATCH --output={log_out}%x_%j.out\n")
            f.write(f"#SBATCH --error={log_err}%x_%j.err\n")
        
        f.write("\n")

        f.write("echo '################## SLURM VARIABLES ##################'\n")
        f.write("echo SLURM_JOB_ID: $SLURM_JOB_ID\n")
        f.write("echo SLURM_JOB_NAME: $SLURM_JOB_NAME\n")
        f.write("echo SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST\n")
        f.write("echo SLURM_NNODES: $SLURM_NNODES\n")
        f.write("echo SLURM_NTASKS: $SLURM_NTASKS\n")
        f.write("echo SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK\n")
        f.write("echo SLURM_JOB_CPUS_PER_NODE: $SLURM_JOB_CPUS_PER_NODE\n")
        f.write("echo SLURM_MEM_PER_CPU: $SLURM_MEM_PER_CPU\n")
        f.write("echo SLURM_MEM_PER_NODE: $SLURM_MEM_PER_NODE\n")
        f.write("echo '#####################################################'\n")
        f.write("\n\n")

        f.write(f"export OMP_NUM_THREADS={nthreads}\n\n")

        f.write("start=`date +%s`\n")
        f.write("%COMMAND%\n")
        f.write("end=`date +%s`\n")
        f.write("runtime=$((end-start))\n")
        f.write("\n")
        f.write("echo ''\n")
        f.write("echo ''\n")
        f.write("echo \"################## RUNTIME ##################\"\n")
        f.write("echo \"Runtime was $runtime seconds\"\n")
        #TODO: fix the two lines below
        # f.write("echo \"Runtime per task was $((${runtime}.0/${SLURM_NTASKS}.0)) seconds\"\n") # not working yet...
        # f.write("echo \"Runtime per total cpus was $((${runtime}.0/${($SLURM_NTASKS*$SLURM_CPUS_PER_TASK)}.0)) seconds\"\n")
        f.write("echo \"#############################################\"\n")
        f.write("\n")
        f.write("exit 0\n")


def create_slurm_script(slurm_template:str,
                        slurm_script:str,
                        job:str,
                        job_config_file:str,
                        job_log:str,
                        array:tuple|None=None,
                        job_name:str|None=None,
                        ):
    """
    Creates a SLURM submission script based on the provided template.
    For three different kind of jobs:
         - monofonic
         - simbelmyne
         - scola
     """
    
    if array is not None and job != "scola":
        raise ValueError(f"Array job range provided for job type {job}.")
    if array is None and job == "scola":
        raise ValueError(f"Array job range not provided for job type {job}.")
    
    from os.path import isfile
    if not isfile(slurm_template):
        raise FileNotFoundError(f"SLURM template {slurm_template} does not exist.")
    
    # Copy template content
    with open(slurm_template, "r") as f:
        template = f.readlines()
    
    command_line = ""
    # Add the job command
    match job:
        case "monofonic":
            command_line = f"{path_to_monofonic_binary} {job_config_file} > {job_log}"
        case "simbelmyne":
            command_line = f"{job} {job_config_file} {job_log}"
        case "scola":
            command_line = f"{job} {job_config_file} {job_log} "+"-b ${SLURM_ARRAY_TASK_ID}"
        case _:
            raise ValueError(f"Job type {job} not recognized.")
    
    # Create the script file
    with open(slurm_script, "w") as f:
        for line in template:
            if job_name is not None and "--job-name" in line:
                line = f"#SBATCH --job-name={job_name}\n"
            if array is not None and "--array" in line:
                line = f"#SBATCH --array={array[0]}-{array[1]}\n"
            if array is not None and ("--output" in line  or "--error" in line):
                line = line.replace("%j","%a_%A")
            if "%COMMAND%" in line:
                line = command_line+"\n"
            f.write(line)
    
    

def main(argv=None):
    """
    Command-line entry point: write a SLURM submission template.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    ``--directory`` is used as a plain string prefix, so it should end with a
    path separator. It provides these defaults:
      - template: ``<directory>params/slurm_<job_name>.template``
      - logs:     ``<directory>slurm_logs/<job_name>_``
    The template and log directories are created if missing.
    """
    import os

    parser = ArgumentParser(description="Generate slurm submission templates.")
    parser.add_argument("-j","--job", type=str, default="monofonic", choices=("monofonic", "simbelmyne", "scola"), help="Job type: monofonic, simbelmyne, scola.")
    parser.add_argument("-N","--ntasks", type=int, default=1, help="Number of tasks.")
    parser.add_argument("-n","--nthreads", type=int, default=32, help="Number of threads per task.")
    parser.add_argument("-p","--partition", type=str, default="comp,pscomp,compl", help="Partition to use.")
    parser.add_argument("-t","--time", type=str, default="0-00:10:00", help="Time limit.")
    parser.add_argument("-m","--mem", type=int, default=64, help="Memory limit.")
    parser.add_argument("-d", "--directory", type=str, default="./", help="Main directory where the output will be saved (if other dir and filenames are not specified).")
    parser.add_argument("-o","--log_out", type=str, default=None, help="File root for the output logs.")
    parser.add_argument("-e","--log_err", type=str, default=None, help="File root for the error logs.")
    parser.add_argument("-a","--array", type=int, nargs=2, default=None, help="Array job range.")
    parser.add_argument("-s","--slurm_template", type=str, default=None, help="Path to the SLURM template.")
    parser.add_argument("-jn","--job_name", type=str, default=None, help="Job name.")

    parsed_args = parser.parse_args(argv)

    # create_slurm_script accepts an array range only for scola jobs; reject
    # the combination here rather than writing a template it cannot use.
    if parsed_args.array is not None and parsed_args.job != "scola":
        parser.error(f"--array is only supported for job type 'scola', not '{parsed_args.job}'.")

    job_name = parsed_args.job if parsed_args.job_name is None else parsed_args.job_name
    slurm_template = parsed_args.slurm_template if parsed_args.slurm_template is not None else f"{parsed_args.directory}params/slurm_{job_name}.template"
    log_out = parsed_args.log_out if parsed_args.log_out is not None else f"{parsed_args.directory}slurm_logs/{job_name}_"
    log_err = parsed_args.log_err if parsed_args.log_err is not None else f"{parsed_args.directory}slurm_logs/{job_name}_"

    # The template is opened for writing and SLURM writes the --output/--error
    # files; neither creates a missing parent directory.
    for target in (slurm_template, log_out, log_err):
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

    create_slurm_template(
        slurm_template=slurm_template,
        job_name=job_name,
        ntasks=parsed_args.ntasks,
        nthreads=parsed_args.nthreads,
        partition=parsed_args.partition,
        time=parsed_args.time,
        mem=parsed_args.mem,
        log_out=log_out,
        log_err=log_err,
        array=parsed_args.array,
    )


if __name__ == "__main__":
    main()