# Copyright (C) 2022 Richard Stiskalek, Harry Desmond
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Support for reading the PHEW/ACACIA CSiBORG merger trees. However, note that
the merger trees are very unreliable.
"""
from abc import ABC
from datetime import datetime
from gc import collect

import numpy
from h5py import File
from tqdm import tqdm, trange
from treelib import Tree

from ..utils import periodic_distance
from .paths import Paths

###############################################################################
#                          Utility functions.                                 #
###############################################################################


def clump_identifier(clump, nsnap):
    """
    Build the string identifier of a clump at a given snapshot.

    The clump ID is left-padded with `x` to 9 characters and the snapshot
    index to 4 characters; the two parts are joined by a double underscore.
    Longer values are not truncated.

    Parameters
    ----------
    clump : int
        Clump ID.
    nsnap : int
        Snapshot index.

    Returns
    -------
    str
    """
    # `!s` converts to a string first, `x>N` right-aligns it with `x` fill.
    clump_part = f"{clump!s:x>9}"
    snap_part = f"{nsnap!s:x>4}"
    return "__".join((clump_part, snap_part))


def extract_identifier(identifier):
    """
    Recover the clump ID and snapshot index from an identifier string of the
    form `<clump>__<nsnap>`, where both parts may be left-padded with `x`.

    Parameters
    ----------
    identifier : str
        Identifier.

    Returns
    -------
    clump, nsnap : int
        Clump ID and snapshot index.
    """
    clump_part, snap_part = identifier.split("__")

    # Drop the `x` padding before converting to integers.
    clump = int(clump_part.lstrip("x"))
    nsnap = int(snap_part.lstrip("x"))
    return clump, nsnap


###############################################################################
#                       Merger tree reader class.                             #
###############################################################################


class BaseMergerReader(ABC):
    """
    Base class for the CSiBORG merger tree reader.

    Provides the `paths`, `nsim` and `min_snap` properties and a cached,
    dictionary-like access to the processed merger tree file. Items are
    requested with keys of the form `"<nsnap>__<kind>"`, where `kind` is
    either the name of a dataset stored under the group `<nsnap>` or the
    special value `clump_to_array`, which yields a mapping from clump ID to
    the tuple of row indices at which it appears in the `clump` dataset.
    """
    _paths = None
    _nsim = None
    _min_snap = None
    # The cache is created lazily for each instance by `_get_cache`. It must
    # not be a dictionary on the class itself, otherwise all readers (even of
    # different simulations) would write to and read from the same cache.
    _cache = None

    @property
    def paths(self):
        """Paths manager."""
        if self._paths is None:
            raise ValueError("`paths` is not set.")
        return self._paths

    @paths.setter
    def paths(self, paths):
        assert isinstance(paths, Paths)
        self._paths = paths

    @property
    def nsim(self):
        """Simulation index."""
        if self._nsim is None:
            raise ValueError("`nsim` is not set.")
        return self._nsim

    @nsim.setter
    def nsim(self, nsim):
        assert isinstance(nsim, (int, numpy.integer))
        self._nsim = nsim

    @property
    def min_snap(self):
        """Minimum snapshot index to read, `None` if there is no limit."""
        return self._min_snap

    @min_snap.setter
    def min_snap(self, min_snap):
        if min_snap is not None:
            assert isinstance(min_snap, (int, numpy.integer))
            min_snap = int(min_snap)
        # Always assign so that passing `None` removes a previous limit.
        self._min_snap = min_snap

    def _get_cache(self):
        """Return the cache of this instance, creating it on first use."""
        if self._cache is None:
            self._cache = {}
        return self._cache

    def cache_length(self):
        """Length of the cache."""
        return len(self._get_cache())

    def cache_clear(self):
        """Clear the cache."""
        self._cache = {}
        collect()

    def __getitem__(self, key):
        """
        Return the item `"<nsnap>__<kind>"`, reading it from the processed
        merger tree file on the first request and from the cache afterwards.

        Parameters
        ----------
        key : str
            Key of the form `"<nsnap>__<kind>"`.

        Returns
        -------
        array or dict
            The whole dataset `<nsnap>/<kind>`, or for `clump_to_array` a
            dictionary mapping each clump ID to a tuple of its row indices.
        """
        try:
            return self._get_cache()[key]
        except KeyError:
            pass

        nsnap, kind = key.split("__")

        if kind == "clump_to_array":
            # Derived purely from the (cached) `clump` dataset, so there is
            # no need to open the file here.
            rows = {}
            for i, c in enumerate(self[f"{nsnap}__clump"]):
                rows.setdefault(c, []).append(i)
            # Plain dictionary of tuples: a missing clump raises `KeyError`.
            x = {c: tuple(indxs) for c, indxs in rows.items()}
        else:
            fname = self.paths.processed_merger_tree(self.nsim)
            with File(fname, "r") as f:
                x = f[f"{nsnap}/{kind}"][:]

        # Cache it
        self._get_cache()[key] = x

        return x


class MergerReader(BaseMergerReader):
    """
    Merger tree reader.

    Parameters
    ----------
    nsim : int
        Simulation index.
    paths : Paths
        Paths manager.
    min_snap : int
        Minimum snapshot index. Trees below this snapshot will not be read.
    """
    def __init__(self, nsim, paths, min_snap=None):
        self.nsim = nsim
        self.paths = paths
        self.min_snap = min_snap

    def get_info(self, current_clump, current_snap, is_main=None):
        """
        Make a list of information about a clump at a given snapshot.

        Parameters
        ----------
        current_clump : int
            Clump ID.
        current_snap : int
            Snapshot index.
        is_main : bool, optional
            Whether this is the main progenitor. If given, it is prepended to
            the output list.

        Returns
        -------
        list
            `[desc_mass, *desc_pos[::-1]]`, preceded by `is_main` if it is
            not `None`.
        """
        if current_clump < 0:
            raise ValueError("Clump ID must be positive.")

        if is_main is not None and not isinstance(is_main, bool):
            raise ValueError("`is_main` must be a boolean.")

        # Row of the first occurrence of this clump in the snapshot's arrays.
        k = self[f"{current_snap}__clump_to_array"][current_clump][0]

        # The position components are returned in reversed order with
        # respect to how they are stored in the file.
        out = [self[f"{current_snap}__desc_mass"][k],
               *self[f"{current_snap}__desc_pos"][k][::-1]]  # TODO REMOVE LATER  # noqa

        if is_main is not None:
            return [is_main,] + out

        return out

    def get_mass(self, clump, snap):
        """
        Get the mass of a clump at a given snapshot.

        Parameters
        ----------
        clump : int
            Clump ID.
        snap : int
            Snapshot index.

        Returns
        -------
        float

        Raises
        ------
        ValueError
            If `clump` is negative.
        KeyError
            If `clump` is not present at snapshot `snap`.
        """
        if clump < 0:
            raise ValueError("Clump ID must be positive.")
        k = self[f"{snap}__clump_to_array"][clump][0]
        return self[f"{snap}__desc_mass"][k]

    def get_pos(self, clump, snap):
        """
        Get the position of a clump at a given snapshot, in the component
        order in which it is stored in the merger tree file.

        Parameters
        ----------
        clump : int
            Clump ID.
        snap : int
            Snapshot index.

        Returns
        -------
        1-dimensional array

        Raises
        ------
        ValueError
            If `clump` is negative.
        KeyError
            If `clump` is not present at snapshot `snap`.
        """
        if clump < 0:
            raise ValueError("Clump ID must be positive.")
        k = self[f"{snap}__clump_to_array"][clump][0]
        return self[f"{snap}__desc_pos"][k]

    def find_main_progenitor(self, clump, nsnap):
        """
        Find the main progenitor of a clump at a given snapshot. Cases are:
            - `clump > 0`, `progenitor > 0`: main progenitor is in the adjacent
            snapshot,
            - `clump > 0`, `progenitor < 0`: main progenitor is not in the
            adjacent snapshot.
            - `clump < 0`, `progenitor = 0`: no progenitor, newly formed clump.

        The absolute value of the stored progenitor ID is returned. If
        `min_snap` is set and `nsnap` is below it, `(0, numpy.nan)` is
        returned, i.e. the clump is treated as having no progenitor.

        Parameters
        ----------
        clump : int
            Clump ID.
        nsnap : int
            Snapshot index.

        Returns
        -------
        progenitor : int
            Main progenitor clump ID.
        progenitor_snap : int
            Main progenitor snapshot index.

        Raises
        ------
        ValueError
            If `clump` is not positive, is not found at snapshot `nsnap` or
            appears there more than once.
        """
        if not clump > 0:
            raise ValueError("Clump ID must be positive.")

        cl2array = self[f"{nsnap}__clump_to_array"]
        if clump in cl2array:
            k = cl2array[clump]
        else:
            raise ValueError("Clump ID not found.")

        if len(k) > 1:
            raise ValueError("Found more than one main progenitor.")
        k = k[0]

        progenitor = abs(self[f"{nsnap}__progenitor"][k])
        progenitor_snap = self[f"{nsnap}__progenitor_outputnr"][k]

        if (self.min_snap is not None) and (nsnap < self.min_snap):
            return 0, numpy.nan

        return progenitor, progenitor_snap

    def find_minor_progenitors(self, clump, nsnap):
        """
        Find the minor progenitors of a clump at a given snapshot. This means
        that `clump < 0`, `progenitor > 0`, i.e. this clump also has another
        main progenitor.

        If there are no minor progenitors, or if `min_snap` is set and `nsnap`
        is below it, return `None` for both lists.

        Parameters
        ----------
        clump : int
            Clump ID.
        nsnap : int
            Snapshot index.

        Returns
        -------
        prog : list
            List of minor progenitor clump IDs.
        prog_snap : list
            List of minor progenitor snapshot indices.
        """
        if not clump > 0:
            raise ValueError("Clump ID must be positive.")

        # Minor progenitors are looked up under the negated clump ID.
        try:
            ks = self[f"{nsnap}__clump_to_array"][-clump]
        except KeyError:
            return None, None

        prog = [self[f"{nsnap}__progenitor"][k] for k in ks]
        prog_nsnap = [self[f"{nsnap}__progenitor_outputnr"][k] for k in ks]

        if (self.min_snap is not None) and (nsnap < self.min_snap):
            return None, None

        return prog, prog_nsnap

    def find_progenitors(self, clump, nsnap):
        """
        Find all progenitors of a clump at a given snapshot. The main
        progenitor is the first element of the list.

        Parameters
        ----------
        clump : int
            Clump ID.
        nsnap : int
            Snapshot index.

        Returns
        -------
        prog : list
            List of progenitor clump IDs.
        prog_nsnap : list
            List of progenitor snapshot indices.

        Raises
        ------
        ValueError
            If the main progenitor is not in the adjacent snapshot and some
            minor progenitor is in a different snapshot than the main one, or
            if there are minor progenitors but no main progenitor.
        """
        main_prog, main_prog_nsnap = self.find_main_progenitor(clump, nsnap)
        min_prog, min_prog_nsnap = self.find_minor_progenitors(clump, nsnap)

        # Check that if the main progenitor is not in the adjacent snapshot,
        # then the minor progenitor are also in that snapshot (if any).
        if (min_prog is not None
                and main_prog_nsnap != nsnap - 1
                and not all(main_prog_nsnap == mprog
                            for mprog in min_prog_nsnap)):
            raise ValueError(f"For clump {clump} at snapshot {nsnap} we have "
                             f"main progenitor at {main_prog_nsnap} and "
                             f"minor progenitors at {min_prog_nsnap}.")

        if min_prog is None:
            prog = [main_prog,]
            prog_nsnap = [main_prog_nsnap,]
        else:
            prog = [main_prog,] + min_prog
            prog_nsnap = [main_prog_nsnap,] + min_prog_nsnap

        if prog[0] == 0 and len(prog) > 1:
            raise ValueError("No main progenitor but minor progenitors "
                             f"found for clump {clump} at snapshot {nsnap}.")

        return prog, prog_nsnap

    def tree_mass_at_snapshot(self, clump, nsnap, target_snap):
        """
        Calculate the total mass of nodes in a tree at a given snapshot.

        The tree rooted at `clump` is walked recursively and the masses of
        all progenitors whose snapshot index equals `target_snap` are summed.
        The root clump itself is not included in the sum.

        Parameters
        ----------
        clump : int
            Clump ID of the root of the (sub)tree.
        nsnap : int
            Snapshot index of `clump`.
        target_snap : int
            Snapshot index at which to sum the masses.

        Returns
        -------
        float
        """
        # If clump is 0 (i.e., we've reached the end of the tree), return 0
        if clump == 0:
            return 0

        # Find the progenitors for the given clump and nsnap
        prog, prog_nsnap = self.find_progenitors(clump, nsnap)

        # No main progenitor, this branch of the tree ends here.
        if prog[0] == 0:
            return 0

        tot = 0
        for p, psnap in zip(prog, prog_nsnap):
            # Mass of the progenitor itself if it is at the target snapshot
            if psnap == target_snap:
                tot += self.get_mass(p, psnap)

            # Recursively add the mass of the progenitor's progenitors
            tot += self.tree_mass_at_snapshot(p, psnap, target_snap)

        return tot

    def is_jumper(self, clump, nsnap, nsnap_descendant):
        """Placeholder, not implemented yet. Always returns `None`."""
        pass

    def make_tree(self, current_clump, current_nsnap,
                  above_clump=None, above_nsnap=None,
                  tree=None, is_main=None, verbose=False):
        """
        Make a merger tree for a clump at a given snapshot.

        The tree is built recursively: a node is added for the current clump
        and the method is then called for each of its progenitors. Each node
        is identified by `clump_identifier` and carries the output of
        `get_info` as its data.

        Parameters
        ----------
        current_clump : int
            Clump ID of the descendant clump.
        current_nsnap : int
            Snapshot index of the descendent clump.
        above_clump : int, optional
            Clump ID of a clump above the current clump in the tree.
        above_nsnap : int, optional
            Snapshot index of a clump above the current clump in the tree.
        tree : treelib.Tree, optional
            Tree to add to.
        is_main : bool, optional
            Whether this is the main progenitor.
        verbose : bool, optional
            Verbosity flag.

        Returns
        -------
        treelib.Tree
            Tree with the current clump as the root. `None` is returned if
            `current_clump` is 0.
        """
        if verbose:
            print(f"{datetime.now()}: Node of a clump {current_clump} at "
                  f"snapshot {current_nsnap}.", flush=True)

        # Terminate if we are at the end of the tree
        if current_clump == 0:
            return

        # Create the root node or add a new node
        if tree is None:
            tree = Tree()
            tree.create_node(
                "root",
                identifier=clump_identifier(current_clump, current_nsnap),
                data=self.get_info(current_clump, current_nsnap, True),
                )
        else:
            tree.create_node(
                identifier=clump_identifier(current_clump, current_nsnap),
                parent=clump_identifier(above_clump, above_nsnap),
                data=self.get_info(current_clump, current_nsnap, is_main),
                )

        # This returns a list of progenitors and their snapshots. The first
        # element is the main progenitor.
        prog, prog_nsnap = self.find_progenitors(current_clump, current_nsnap)

        for i, (p, psnap) in enumerate(zip(prog, prog_nsnap)):
            self.make_tree(p, psnap, current_clump, current_nsnap, tree,
                           is_main=i == 0, verbose=verbose)

        return tree

    def walk_main_progenitor(self, main_clump, main_nsnap, verbose=False):
        """
        Walk the main progenitor branch of a clump.

        Each row of the output contains information about the clump on the
        main branch at one snapshot and about its progenitors. The walk
        stops at the first clump without a main progenitor, whose row is
        still included.

        Parameters
        ----------
        main_clump : int
            Clump ID at which to start the walk.
        main_nsnap : int
            Snapshot index of `main_clump`.
        verbose : bool, optional
            Verbosity flag, controls the progress bar.

        Returns
        -------
        structured array
            With fields `desc_snapshot_index`, `desc_mass`, `desc_x`,
            `desc_y`, `desc_z`, `prog_snapshot_index`, `prog_totmass`,
            `mainprog_mass`, `minprog_totmass` and `merger_ratio`. The last
            one is the ratio of the most massive minor progenitor mass to
            the main progenitor mass.
        """
        out = []

        # The context manager closes the progress bar even if an error is
        # raised while walking the branch.
        with tqdm(disable=not verbose) as pbar:
            while True:
                prog, prog_nsnap = self.find_progenitors(main_clump,
                                                         main_nsnap)

                # Unpack the main and minor progenitor
                mainprog, mainprog_nsnap = prog[0], prog_nsnap[0]
                if len(prog) > 1:
                    minprog, minprog_nsnap = prog[1:], prog_nsnap[1:]
                else:
                    minprog, minprog_nsnap = None, None

                # If there is no progenitor, then set the main progenitor
                # mass to NaN
                if mainprog == 0:
                    mainprog_mass = numpy.nan
                else:
                    mainprog_mass = self.get_mass(mainprog, mainprog_nsnap)

                totprog_mass = mainprog_mass

                # Unpack masses of the minor progenitors
                if minprog is not None:
                    minprog_masses = [self.get_mass(c, n)
                                      for c, n in zip(minprog, minprog_nsnap)]

                    max_minprog_mass = max(minprog_masses)
                    minprog_totmass = sum(minprog_masses)
                    totprog_mass += minprog_totmass
                else:
                    minprog_totmass = numpy.nan
                    max_minprog_mass = numpy.nan

                out += [
                    [main_nsnap,]
                    + self.get_info(main_clump, main_nsnap)
                    + [mainprog_nsnap, totprog_mass, mainprog_mass,
                       minprog_totmass, max_minprog_mass / mainprog_mass]
                    ]

                pbar.update(1)
                pbar.set_description(f"Clump {main_clump} ({main_nsnap})")

                if mainprog == 0:
                    break

                main_clump = mainprog
                main_nsnap = mainprog_nsnap

        # Convert output to a structured array. We store integers as float
        # to avoid errors because of converting NaNs to integers.
        out = numpy.vstack(out)
        dtype = [("desc_snapshot_index", numpy.float32),
                 ("desc_mass", numpy.float32),
                 ("desc_x", numpy.float32),
                 ("desc_y", numpy.float32),
                 ("desc_z", numpy.float32),
                 ("prog_snapshot_index", numpy.float32),
                 ("prog_totmass", numpy.float32),
                 ("mainprog_mass", numpy.float32),
                 ("minprog_totmass", numpy.float32),
                 ("merger_ratio", numpy.float32),
                 ]

        return numpy.array([tuple(row) for row in out], dtype=dtype)

    def match_mass_to_phewcat(self, phewcat):
        """
        For each clump mass in the PHEW catalogue, find the corresponding
        clump mass in the merger tree file. If no match is found returns NaN.
        These are not equal because the PHEW catalogue mass is the mass without
        unbinding.

        Parameters
        ----------
        phewcat : csiborgtools.read.CSiBORGPEEWReader
            PHEW catalogue reader.

        Returns
        -------
        mass : 1-dimensional array of shape `(n_clumps, )`
            Merger tree mass of each clump in `phewcat["index"]`.
        """
        if phewcat.nsim != self.nsim:
            raise ValueError("Simulation indices do not match.")

        nsnap = phewcat.nsnap
        indxs = phewcat["index"]
        mergertree_mass = numpy.full(len(indxs), numpy.nan,
                                     dtype=numpy.float32)

        for i, ind in enumerate(indxs):
            # Clumps missing from the merger tree keep their NaN.
            try:
                mergertree_mass[i] = self.get_mass(ind, nsnap)
            except KeyError:
                continue

        return mergertree_mass

    def match_pos_to_phewcat(self, phewcat):
        """
        For each clump in the PHEW catalogue, find the corresponding clump
        position in the merger tree file. If no match is found returns NaN.

        Parameters
        ----------
        phewcat : csiborgtools.read.CSiBORGPEEWReader
            PHEW catalogue reader.

        Returns
        -------
        pos : 2-dimensional array of shape `(n_clumps, 3)`
            Merger tree position of each clump in `phewcat["index"]`, with
            the components in reversed order with respect to the file.
        """
        if phewcat.nsim != self.nsim:
            raise ValueError("Simulation indices do not match.")

        nsnap = phewcat.nsnap
        indxs = phewcat["index"]
        mergertree_pos = numpy.full((len(indxs), 3), numpy.nan,
                                    dtype=numpy.float32)

        for i, ind in enumerate(indxs):
            # Clumps missing from the merger tree keep their NaNs.
            try:
                mergertree_pos[i] = self.get_pos(ind, nsnap)
            except KeyError:
                continue

        return mergertree_pos[:, ::-1]  # TODO later remove


###############################################################################
#                           Manual halo tracking.                             #
###############################################################################


def track_halo_manually(cats, hid, maxdist=0.15, max_dlogm=0.35):
    """
    Manually track a halo without using the merger tree. Searches for nearby
    halo of similar mass in adjacent snapshots. Supports only main haloes and
    can only work for the most massive haloes in a simulation, however even
    then significant care should be taken.

    Among the haloes that are within `maxdist` of the last known position and
    whose |log mass ratio| with respect to the last known mass is below
    `max_dlogm`, selects the one closest in log mass to be a match.

    In case a progenitor is not found in a snapshot, its entry is left as NaN
    and the search continues in the next (earlier) snapshot using the last
    known position and mass. Occasionally some haloes disappear.

    Parameters
    ----------
    cats : dict
        Dictionary of halo catalogues, keys are snapshot indices. The
        tracking starts at the largest key and proceeds through the
        remaining keys in decreasing order.
    hid : int
        Halo ID in the catalogue of the latest snapshot.
    maxdist : float, optional
        Maximum comoving distance for a halo to move between adjacent
        snapshots.
    max_dlogm : float, optional
        Maximum |log mass ratio| for a halo to be considered a progenitor.

    Returns
    -------
    hist : structured array
        History of the halo with fields `snapshot_index`, `x`, `y`, `z`,
        `mass` and `desc_dist`, one entry per catalogue.

    Raises
    ------
    ValueError
        If `cats` is empty or if the halo is not a main halo, has
        `summed_mass` not above 1e13 or `dist` not below 155.5 in the latest
        snapshot.
    """
    if len(cats) == 0:
        raise ValueError("`cats` must contain at least one catalogue.")

    # Use the snapshots that are actually available, latest first, instead
    # of assuming that their indices are contiguous.
    snaps = sorted(cats, reverse=True)
    nsnap0 = snaps[0]

    k = cats[nsnap0]["hid_to_array_index"][hid]
    pos = cats[nsnap0]["cartesian_pos"][k]
    mass = cats[nsnap0]["summed_mass"][k]

    if not cats[nsnap0]["is_main"][k]:
        raise ValueError("Only main haloes are supported.")

    if not mass > 1e13:
        raise ValueError("Only the most massive haloes are supported.")

    if not cats[nsnap0]["dist"][k] < 155.5:
        raise ValueError("Only high-resolution region haloes are supported.")

    dtype = [("snapshot_index", numpy.float32),
             ("x", numpy.float32),
             ("y", numpy.float32),
             ("z", numpy.float32),
             ("mass", numpy.float32),
             ("desc_dist", numpy.float32),
             ]
    hist = numpy.full(len(snaps), numpy.nan, dtype=dtype)
    hist["snapshot_index"][0] = nsnap0
    hist["x"][0], hist["y"][0], hist["z"][0] = pos
    hist["mass"][0] = mass

    for n in trange(1, len(snaps), desc="Tracking halo"):
        nsnap = snaps[n]
        hist["snapshot_index"][n] = nsnap

        # Preselect haloes in a box of width `2 * maxdist` around the last
        # known position.
        indxs = cats[nsnap].select_in_box(pos, 2 * maxdist)

        if len(indxs) == 0:
            continue

        nearby_pos = cats[nsnap]["cartesian_pos"][indxs]
        nearby_mass = cats[nsnap]["summed_mass"][indxs]

        # Distance from the previous position and |log mass ratio|
        dist = periodic_distance(nearby_pos, pos, cats[nsnap].box.boxsize)
        dlogm = numpy.abs(numpy.log10(nearby_mass / mass))

        # Apply both cuts before picking the best match, so that a halo
        # which is close in mass but too far away does not shadow a valid
        # candidate.
        valid = numpy.flatnonzero((dlogm < max_dlogm) & (dist < maxdist))
        if len(valid) == 0:
            continue

        k = valid[numpy.argmin(dlogm[valid])]

        hist["x"][n], hist["y"][n], hist["z"][n] = nearby_pos[k]
        hist["mass"][n] = nearby_mass[k]
        hist["desc_dist"][n] = dist[k]

        # The match becomes the reference for the next snapshot.
        pos = nearby_pos[k]
        mass = nearby_mass[k]

    return hist