import numpy as np
import plotly.graph_objects as go
import atom_properties as atomp
from itertools import accumulate

class XYZ:
    """Molecular geometry (elements and cartesian coordinates) with plotly mesh rendering.

    Attributes:
        sphere_mode: How atoms are drawn. "vdw" uses the full van der Waals
            radius, "ball" uses a fraction of it (see ``_SPHERE_SCALES``).
        xyz: Array of shape (n_atoms, 3) holding the atomic coordinates.
        elements: Element symbols, in the same order as the rows of ``xyz``.
    """

    sphere_mode = "ball"

    # Factor applied to the van der Waals radius for each supported sphere_mode.
    _SPHERE_SCALES = {"vdw": 1.0, "ball": 0.2}

    def __init__(self, coords: list[dict[str, str | float]]):
        """
        Build the geometry from a list of atom records.

        Args:
            coords (list[dict]): One dict per atom with the keys "element",
                "x", "y" and "z".
        """
        # reshape keeps the (n_atoms, 3) layout even when coords is empty.
        self.xyz = np.array(
            [[coord["x"], coord["y"], coord["z"]] for coord in coords],
            dtype=float,
        ).reshape(-1, 3)
        self.elements = [coord["element"] for coord in coords]

    @classmethod
    def from_file(cls, file: str) -> "XYZ":
        """
        Read a coordinate file whose atom lines have the structure 'element x y z'.

        If the first line consists of a single token, the first two lines are
        treated as a header and skipped. Files without such a header are read
        from the first line. Blank lines are ignored.

        Args:
            file (str): Path to the file to be read in.

        Returns:
            XYZ: A new instance holding the elements and coordinates from the file.
        """
        with open(file) as f:
            lines = f.readlines()

        # A single-token first line marks a two-line header that carries no atoms.
        if lines and len(lines[0].split()) == 1:
            lines = lines[2:]

        coords = []
        for line in lines:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            coords.append({
                "element": fields[0],
                "x": float(fields[1]),
                "y": float(fields[2]),
                "z": float(fields[3]),
            })
        return cls(coords)

    def get_bond_length(self, i: int, j: int) -> float:
        """
        Returns the bond length between atoms i and j.

        Args:
            i (int): Index of the first atom.
            j (int): Index of the second atom.

        Returns:
            float: The euclidean distance between the two atoms.
        """
        return float(np.linalg.norm(self.xyz[i] - self.xyz[j]))

    def circle_coordinates(self, point: np.ndarray, v: np.ndarray, radius: float, num_points: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate cartesian coordinates of a circle in 3D space.

        Args:
            point (np.ndarray): The center of the circle.
            v (np.ndarray): A non-zero vector perpendicular to the plane of the circle.
            radius (float): The radius of the circle.
            num_points (int): The number of points to generate on the circle.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: The x, y and z coordinates
                of ``num_points`` distinct, evenly spaced points on the circle.
        """
        # Unit vector along the circle's normal
        axis = v / np.linalg.norm(v)

        # First in-plane vector: cross with the x axis, falling back to the
        # y axis when the normal is (nearly) parallel to x.
        u = np.cross(axis, [1, 0, 0])
        if np.linalg.norm(u) <= 1e-6:
            u = np.cross(axis, [0, 1, 0])
        u = u / np.linalg.norm(u)

        # Second in-plane vector, perpendicular to both u and the normal
        w = np.cross(u, axis)

        # endpoint=False: 0 and 2*pi are the same point, so including both
        # would waste one vertex on a duplicate.
        theta = np.linspace(0, 2 * np.pi, num_points, endpoint=False)
        circle_coords = point + radius * (u * np.cos(theta)[:, np.newaxis] + w * np.sin(theta)[:, np.newaxis])

        return circle_coords[:, 0], circle_coords[:, 1], circle_coords[:, 2]

    def make_fibonacci_sphere(self, center: np.ndarray, radius: float = 0.1, resolution: int = 32) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Return cartesian coordinates of points evenly distributed on the surface of a sphere.

        Args:
            center (np.ndarray): The coordinates of the center of the sphere.
            radius (float, optional): The radius of the sphere. Defaults to 0.1.
            resolution (int, optional): The number of points to be generated. Defaults to 32.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: The cartesian coordinates of the points
                on the surface of the sphere. The three arrays are the x, y and z coordinates
                of the points.
        """
        num_points = resolution
        # Half-integer offsets keep the points away from the exact poles.
        indices = np.arange(0, num_points, dtype=float) + 0.5
        phi = np.arccos(1 - 2 * indices / num_points)
        theta = np.pi * (1 + 5**0.5) * indices

        x = radius * np.sin(phi) * np.cos(theta) + center[0]
        y = radius * np.sin(phi) * np.sin(theta) + center[1]
        z = radius * np.cos(phi) + center[2]

        return x, y, z

    def make_cylinder(self, center_i: int, center_j: int, resolution: int = 32, radius: float = 0.1) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Return cartesian coordinates of points on the edges of a cylinder.

        The cylinder axis is the line between atoms i and j. One circle of
        ``resolution // 4`` points is generated around each of the two atoms.

        Args:
            center_i (int): Index of the first atom that defines the cylinder.
            center_j (int): Index of the second atom that defines the cylinder.
            resolution (int, optional): Controls the point count; each end circle
                gets ``resolution // 4`` points. Defaults to 32.
            radius (float, optional): The radius of the cylinder. Defaults to 0.1.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: The x, y and z coordinates of the
                points on both end circles (first atom's circle first).
        """
        v = self.xyz[center_j] - self.xyz[center_i]
        points_per_circle = resolution // 4
        x1, y1, z1 = self.circle_coordinates(self.xyz[center_i], v, radius=radius, num_points=points_per_circle)
        x2, y2, z2 = self.circle_coordinates(self.xyz[center_j], v, radius=radius, num_points=points_per_circle)
        return np.concatenate((x1, x2)), np.concatenate((y1, y2)), np.concatenate((z1, z2))

    def make_atom_mesh(self, i: int, resolution: int = 32) -> go.Mesh3d:
        """
        Return a Mesh3d object that represents the atom at index i.

        The atom is represented as a sphere whose radius is the element's van der
        Waals radius scaled according to ``sphere_mode``.

        Args:
            i (int): Index of the atom to be represented.
            resolution (int, optional): The number of points to be generated. Defaults to 32.

        Returns:
            go.Mesh3d: A Mesh3d object that represents the atom.

        Raises:
            ValueError: If ``sphere_mode`` is not one of the supported modes.
        """
        scale = self._SPHERE_SCALES.get(self.sphere_mode)
        if scale is None:
            raise ValueError(
                f"Unknown sphere_mode {self.sphere_mode!r}; "
                f"expected one of {sorted(self._SPHERE_SCALES)}"
            )
        element = self.elements[i]
        radius = atomp.vdw_radii_dict[element] * scale

        x, y, z = self.make_fibonacci_sphere(self.xyz[i], radius=radius, resolution=resolution)

        atom_mesh = go.Mesh3d(
            x=x,
            y=y,
            z=z,
            color=atomp.atom_colors_dict[element],
            opacity=1,
            alphahull=0,
            name=f'{element}{i}',  # label is too short to also show coordinates
            hoverinfo='name',  # Only show the name on hover
        )
        return atom_mesh

    def make_bond_mesh(self, i: int, j: int, resolution: int = 32, radius: float = 0.1) -> go.Mesh3d:
        """
        Return a Mesh3d object that represents a bond between atoms i and j.

        The bond is represented as a cylinder of the given radius.

        Args:
            i (int): Index of the first atom in the bond.
            j (int): Index of the second atom in the bond.
            resolution (int, optional): Point-count control passed on to make_cylinder. Defaults to 32.
            radius (float, optional): The radius of the cylinder. Defaults to 0.1.

        Returns:
            go.Mesh3d: A Mesh3d object that represents the bond.
        """
        x, y, z = self.make_cylinder(i, j, resolution=resolution, radius=radius)

        bond_mesh = go.Mesh3d(
            x=x,
            y=y,
            z=z,
            color="#444444",
            opacity=1,
            alphahull=0,
            hoverinfo='none',  # No hover info at all
        )
        return bond_mesh

    def get_molecular_mesh(self, resolution: int = 64, rel_cutoff: float = 0.5) -> list[go.Mesh3d]:
        """
        Returns a list of Mesh3d objects that represent the molecular structure.

        Bond meshes come first, followed by one mesh per atom.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): Bonds are only shown if the interatomic distance is smaller
                than the sum of the atomic vdw radii multiplied with this parameter.
                Defaults to 0.5.

        Returns:
            list[go.Mesh3d]: The Mesh3d objects that represent the molecular structure.
        """
        n_atoms = len(self.elements)
        # Look up each radius once instead of once per atom pair.
        vdw_radii = [atomp.vdw_radii_dict[element] for element in self.elements]

        mesh_list = []
        # Add the bonds
        for i in range(n_atoms):
            for j in range(i + 1, n_atoms):
                max_bond_length = rel_cutoff * (vdw_radii[i] + vdw_radii[j])
                if self.get_bond_length(i, j) < max_bond_length:
                    mesh_list.append(self.make_bond_mesh(i, j, resolution=resolution, radius=0.1))

        # Add the atoms
        for i in range(n_atoms):
            mesh_list.append(self.make_atom_mesh(i, resolution=resolution))

        return mesh_list

    def get_molecular_viewer(self, resolution: int = 64, rel_cutoff: float = 0.5) -> go.Figure:
        """
        Returns a plotly figure with the molecular structure.

        The molecular structure is represented as a collection of spheres at the atomic positions,
        with the bonds represented as cylinders between the atoms.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): Bonds are only shown if the interatomic distance is smaller
                than the sum of the atomic vdw radii multiplied with this parameter.
                Defaults to 0.5.

        Returns:
            go.Figure: A plotly figure with the molecular structure.
        """
        fig = go.Figure()

        for mesh in self.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff):
            fig.add_trace(mesh)

        # Fix aspect ratio such that the molecule is displayed undistorted
        fig.update_layout(scene_aspectmode='data')

        return fig

class Trajectory:
    """Sequence of XYZ geometries shown as one plotly figure with a frame slider.

    Attributes:
        coordinates: The frames of the trajectory, in display order.
    """

    def __init__(self, coordinates: list[XYZ]):
        """
        Store the frames of the trajectory.

        Args:
            coordinates (list[XYZ]): One XYZ object per frame.
        """
        self.coordinates = coordinates

    @staticmethod
    def _visibility_masks(trace_counts: list[int]) -> list[list[bool]]:
        """
        Build one per-trace visibility mask for every frame.

        Args:
            trace_counts (list[int]): Number of traces belonging to each frame,
                in the order the traces were added to the figure.

        Returns:
            list[list[bool]]: For each frame a list with one entry per trace of the
                whole figure; True exactly for the traces of that frame.
        """
        total = sum(trace_counts)
        masks = []
        start = 0
        for count in trace_counts:
            mask = [False] * total
            # Same-length slice assignment, so the mask never changes size.
            mask[start:start + count] = [True] * count
            masks.append(mask)
            start += count
        return masks

    def get_molecular_viewer(self, resolution: int = 64, rel_cutoff: float = 0.5) -> go.Figure:
        """
        Returns a plotly figure with the molecular structure of the trajectory.

        The meshes of all frames are added to a single figure. A slider switches
        which frame's meshes are visible; initially only the first frame is shown.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): Bonds are only shown if the interatomic distance is smaller
                than the sum of the atomic vdw radii multiplied with this parameter.
                Defaults to 0.5.

        Returns:
            go.Figure: A plotly figure with the molecular structure.
        """
        fig = go.Figure()

        # Get and add all the meshes, remembering how many belong to each frame
        trace_counts = []
        for frame_index, xyz in enumerate(self.coordinates):
            meshes = xyz.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff)
            trace_counts.append(len(meshes))
            for mesh in meshes:
                # Only the first frame starts visible, matching the slider's
                # initial position; otherwise all frames would overlap.
                mesh.visible = frame_index == 0
                fig.add_trace(mesh)

        # One slider step per frame, each showing only that frame's traces
        slider_steps: list[dict] = [
            {
                "method": "restyle",
                "label": str(frame_index),
                "args": [{"visible": mask}],
            }
            for frame_index, mask in enumerate(self._visibility_masks(trace_counts))
        ]

        # Make the slider
        slider = {
            "active": 0,
            "currentvalue": {"prefix": "Frame: "},
            "steps": slider_steps
        }

        fig.update_layout(
            sliders=[slider],
            # Fix aspect ratio such that the molecule is displayed undistorted
            scene_aspectmode='data',
        )

        return fig