Skip to content
Snippets Groups Projects
support.py 30.6 KiB
Newer Older
import numpy as np
import plotly.graph_objects as go
import atom_properties as atomp
from itertools import accumulate
import qc_parser as qcp
import subprocess as sub
import os
from typing import Iterable, Sequence

import pyvista as pv
import sklearn.cluster as skc

# Make the ORCA tools (orca_plot, orca_pltvib) findable for the subprocess calls
# below. PATH may be missing in minimal environments, so read it with a default
# and avoid a leading separator when it is empty.
if "orca" not in os.environ.get("PATH", ""):
    os.environ["PATH"] = os.pathsep.join(
        part for part in (os.environ.get("PATH", ""), "/home/jovyan/orca") if part
    )

moritz.buchhorn's avatar
moritz.buchhorn committed

def circle_coordinates(
    point: np.ndarray, v: np.ndarray, radius: float, num_points: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate cartesian coordinates of points on a circle in 3D space.

    Args:
        point (np.ndarray): The center of the circle.
        v (np.ndarray): The normal vector of the circle's plane; it does not need
            to be normalized.
        radius (float): The radius of the circle.
        num_points (int): The number of distinct points to generate on the circle.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: The x, y and z coordinates of the
            points, each a 1D array of length num_points.
    """
    # Normalize the vector v
    v = np.asarray(v, dtype=float)
    v = v / np.linalg.norm(v)

    # Create a vector u that is perpendicular to v. The cross product with the x axis
    # vanishes when v is (nearly) parallel to it; use the y axis in that case.
    u = np.cross(v, [1, 0, 0])
    if np.linalg.norm(u) <= 1e-6:
        u = np.cross(v, [0, 1, 0])
    u = u / np.linalg.norm(u)

    # Create a vector w that is perpendicular to both u and v
    w = np.cross(u, v)

    # endpoint=False: including 2*pi would repeat the point at angle 0, leaving only
    # num_points - 1 distinct points (e.g. a triangle instead of a square for 4).
    theta = np.linspace(0, 2 * np.pi, num_points, endpoint=False)
    circle_coords = np.asarray(point, dtype=float) + radius * (
        np.outer(np.cos(theta), u) + np.outer(np.sin(theta), w)
    )

    return circle_coords[:, 0], circle_coords[:, 1], circle_coords[:, 2]

moritz.buchhorn's avatar
moritz.buchhorn committed

def make_arrow_mesh(
    at: np.ndarray, dir: np.ndarray, resolution: int = 16, radius: float = 0.05
) -> go.Mesh3d:
    """
    Return a semi-transparent cylinder that visualises the vector ``dir`` at ``at``.

    The cylinder is centred on ``at`` and runs from ``at - dir`` to ``at + dir``. Its
    surface is the convex hull (``alphahull=0``) of one circle at each end.

    Args:
        at (np.ndarray): Position the cylinder is centred on (array or sequence).
        dir (np.ndarray): Half-length vector of the cylinder (array or sequence).
        resolution (int, optional): Each end circle gets ``resolution // 4`` points.
            Defaults to 16.
        radius (float, optional): Radius of the cylinder. Defaults to 0.05.

    Returns:
        go.Mesh3d: The cylinder mesh. For a zero-length ``dir`` an empty mesh is
            returned, because no cylinder axis is defined.
    """
    # Coerce to arrays: with plain lists, "+" would concatenate instead of add.
    center = np.asarray(at, dtype=float)
    half = np.asarray(dir, dtype=float)
    if not np.any(half):
        # circle_coordinates normalises the axis; a zero vector would yield NaN points
        return go.Mesh3d(hoverinfo="none")

    tip = center + half
    bottom = center - half
    axis = bottom - tip
    n_circle = resolution // 4
    x1, y1, z1 = circle_coordinates(bottom, axis, radius=radius, num_points=n_circle)
    x2, y2, z2 = circle_coordinates(tip, axis, radius=radius, num_points=n_circle)

    return go.Mesh3d(
        x=np.concatenate((x1, x2)),
        y=np.concatenate((y1, y2)),
        z=np.concatenate((z1, z2)),
        color="#888888",
        opacity=0.5,
        alphahull=0,
        hoverinfo="none",  # No hover info at all
    )

moritz.buchhorn's avatar
moritz.buchhorn committed

def make_fibonacci_sphere(
    center: np.ndarray, radius: float = 0.1, resolution: int = 32
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Place ``resolution`` points on a sphere using a Fibonacci (golden-angle) lattice.

    Args:
        center (np.ndarray): Cartesian coordinates of the sphere's center.
        radius (float, optional): Sphere radius. Defaults to 0.1.
        resolution (int, optional): How many points to place. Defaults to 32.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: Arrays with the x, y and z
            coordinates of the points on the sphere surface.
    """
    cx, cy, cz = center[0], center[1], center[2]

    # Half-integer sample positions avoid putting points exactly on the poles.
    k = np.arange(0, resolution, dtype=float) + 0.5
    polar = np.arccos(1 - 2 * k / resolution)
    azimuth = np.pi * (1 + 5**0.5) * k

    ring = radius * np.sin(polar)
    return (
        ring * np.cos(azimuth) + cx,
        ring * np.sin(azimuth) + cy,
        radius * np.cos(polar) + cz,
    )

class Volumetric:
    """
    Scalar values on a regular 3D grid, e.g. an orbital or a density from a cube file.

    The grid point with indices (x, y, z) sits at
    ``origin + x * vec1 + y * vec2 + z * vec3``, and its value is stored at the flat
    index ``z + y * res_vec3 + x * res_vec3 * res_vec2`` of ``volumetrics``. The last
    index runs fastest.
    """

    def __init__(
        self,
        origin: Sequence[float],
        res_vec1: int,
        res_vec2: int,
        res_vec3: int,
        vec1: Sequence[float],
        vec2: Sequence[float],
        vec3: Sequence[float],
        volumetrics: Sequence[float],
    ) -> None:
        self.origin = np.array(origin)
        self.res_vec1 = res_vec1
        self.res_vec2 = res_vec2
        self.res_vec3 = res_vec3
        self.vec1 = np.array(vec1)
        self.vec2 = np.array(vec2)
        self.vec3 = np.array(vec3)
        self.volumetrics = volumetrics

    @classmethod
    def from_cube_file(cls, file: str):
        """
        Read a cube file.

        Lengths are scaled by 0.52917721092 (Bohr -> Angstrom) unless the atom count
        in the header is zero. This heuristic is kept from the original
        implementation.

        Args:
            file (str): Path to the cube file.

        Returns:
            Volumetric: The grid geometry and values.
        """
        with open(file, "r") as f:
            # two comment lines
            next(f)
            next(f)
            natoms_raw, *origin_raw = f.readline().split()
            axes_raw = [f.readline().split() for _ in range(3)]

            bohr_factor = 1.0
            if int(natoms_raw) != 0:  # should be different, but whatever works
                bohr_factor = 0.52917721092

            natoms = abs(int(natoms_raw))
            origin = tuple(float(c) * bohr_factor for c in origin_raw[:3])
            resolutions = [int(axis[0]) for axis in axes_raw]
            vectors = [
                tuple(float(c) * bohr_factor for c in axis[1:4]) for axis in axes_raw
            ]

            # skip all the atom coordinates
            for _ in range(natoms):
                next(f)
            # MO cube files carry one extra line after the atoms. Only the file name
            # is inspected, so a directory whose name contains ".mo" no longer causes
            # a data line to be dropped.
            if ".mo" in os.path.basename(file):
                next(f)

            volumetrics = [float(value) for line in f for value in line.split()]
        return cls(origin, *resolutions, *vectors, volumetrics)

    @classmethod
    def from_gbw_file(cls, file: str, index: int, mode: str):
        """
        Create a cube file from an ORCA .gbw file with ``orca_plot`` and read it.

        The temporary cube file is deleted afterwards, also when reading it fails.

        Args:
            file (str): Path to the .gbw file.
            index (int): Orbital index for mode "mo", state index for mode "diffdens".
            mode (str): "mo" for a molecular orbital, "diffdens" for a difference density.

        Returns:
            Volumetric: The volumetric data written by orca_plot.

        Raises:
            ValueError: If ``mode`` is not "mo" or "diffdens".
            FileNotFoundError: If orca_plot did not produce the expected cube file.
        """
        # keystrokes for orca_plot's interactive menu, and the resulting file suffix
        menu_input = {
            "mo": f"5\n7\n4\n100\n2\n{index}\n10\n11\n",
            "diffdens": f"5\n7\n4\n100\n6\ny\n{index}\n11\n",
        }
        cube_suffix = {"mo": f"mo{index}a", "diffdens": f"cisdp{index}"}
        if mode not in menu_input:
            raise ValueError(f"mode must be 'mo' or 'diffdens', got {mode!r}")

        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed to orca_plot unchanged.
        proc = sub.run(
            ["orca_plot", file, "-i"],
            text=True,
            capture_output=True,
            input=menu_input[mode],
        )
        basename = file[: -len(".gbw")] if file.endswith(".gbw") else file
        cubefile = f"{basename}.{cube_suffix[mode]}.cube"
        if not os.path.exists(cubefile):
            raise FileNotFoundError(
                f"orca_plot did not create {cubefile}:\n{proc.stderr or proc.stdout}"
            )
        try:
            return cls.from_cube_file(cubefile)
        finally:
            os.remove(cubefile)

    def _signed_grid(self, sign: int) -> np.ndarray:
        """Return the grid values times ``sign``, shaped (res_vec1, res_vec2, res_vec3)."""
        n_points = self.res_vec1 * self.res_vec2 * self.res_vec3
        values = np.asarray(self.volumetrics, dtype=float)[:n_points]
        return sign * values.reshape(self.res_vec1, self.res_vec2, self.res_vec3)

    def voxel_generator(self, threshold: float):
        """
        Yield the grid points whose value exceeds ``threshold``.

        A negative threshold selects the negative lobe: values are multiplied by -1
        and compared against ``abs(threshold)``. Points are yielded in grid order
        (x slowest, z fastest).

        Args:
            threshold (float): Isovalue; its sign selects the lobe.

        Yields:
            tuple: ``(value, x, y, z)``. ``value`` is the sign-adjusted grid value and
                ``x, y, z`` are cartesian coordinates.
        """
        sign = -1 if threshold < 0 else 1
        grid = self._signed_grid(sign)
        # Vectorised selection replaces a pure-Python loop over every grid point;
        # argwhere returns indices in the same C order as the nested loops would.
        hits = np.argwhere(grid > abs(threshold))
        coords = (
            hits[:, [0]] * self.vec1
            + hits[:, [1]] * self.vec2
            + hits[:, [2]] * self.vec3
            + self.origin
        )
        for (ix, iy, iz), point in zip(hits, coords):
            yield (float(grid[ix, iy, iz]), *point)

    def get_mesh(
        self,
        threshold: float,
        color: str = "#888888",
        verbose: bool = False,
        dynamic_threshold: bool = False,
    ) -> "go.Mesh3d":
        """
        Build an isosurface mesh of the grid points beyond ``threshold``.

        The points are clustered with DBSCAN. Each cluster is turned into a surface
        with pyvista's Delaunay triangulation, and all surfaces are merged into one
        Mesh3d.

        Args:
            threshold (float): Isovalue; a negative value selects the negative lobe.
            color (str, optional): Mesh color. Defaults to "#888888".
            verbose (bool, optional): Print progress messages. Defaults to False.
            dynamic_threshold (bool, optional): If no point exceeds ``threshold``,
                shrink it by factors of 0.9 until one does. Then use a tenth of that
                value. Defaults to False.

        Returns:
            go.Mesh3d: The semi-transparent mesh. It is empty if no points are found.
        """
        voxels_lst = list(self.voxel_generator(threshold))
        if not voxels_lst:
            if not dynamic_threshold:
                print(f"No voxels found for given threshold {threshold:.2e}")
                return go.Mesh3d()
            # Shrinking the threshold can only help if some value has the requested
            # sign; otherwise the search below would never terminate.
            grid = self._signed_grid(-1 if threshold < 0 else 1)
            if grid.size == 0 or grid.max() <= 0:
                print(f"No voxels of the requested sign found (threshold {threshold:.2e})")
                return go.Mesh3d()
            print(f"Threshold too high, finding new threshold ...")
            while not voxels_lst:
                threshold *= 0.9
                voxels_lst = list(self.voxel_generator(threshold))
            print(f"Threshold found: {threshold:.2e}")
            threshold *= 0.1
            voxels_lst = list(self.voxel_generator(threshold))
            print(f"Threshold used: {threshold:.2e}")
        voxels = np.array(voxels_lst)

        if verbose:
            print("Clustering ...")
        labels = skc.DBSCAN(eps=0.5).fit(voxels[:, 1:4]).labels_
        if verbose:
            print(" done.")

        # Label -1 marks DBSCAN noise points; those are not meshed.
        n_clusters = labels.max() + 1
        points: list[np.ndarray] = []
        faces: list[np.ndarray] = []
        n_points = 0
        for i in range(n_clusters):
            if verbose:
                print(f"Creating mesh {i + 1}/{n_clusters} ...", end="")
            cloud = pv.PolyData(voxels[labels == i][:, 1:4])
            surf = cloud.delaunay_3d(alpha=0.4, tol=0.001, offset=2.5)
            surf = surf.extract_geometry().triangulate().decimate(0.9)
            # faces are stored as [3, i, j, k]; shift indices past earlier clusters
            faces.append(surf.faces.reshape(-1, 4)[:, 1:] + n_points)
            points.append(surf.points)
            n_points += len(surf.points)
            if verbose:
                print(" done.")

        xyz = np.concatenate(points) if points else np.empty((0, 3))
        ijk = np.concatenate(faces) if faces else np.empty((0, 3), dtype=int)
        return go.Mesh3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            i=ijk[:, 0],
            j=ijk[:, 1],
            k=ijk[:, 2],
            hoverinfo="none",
            color=color,
            opacity=0.5,  # keep the molecule visible through the isosurface
        )

    # THIS IS AN ALTERNATIVE VERSION OF THE GET_MESH METHOD
    # IT DOES NOT USE CLUSTERING
    # BUT SOMETIMES THE ORBITALS LOOK LESS GOOD THEN
    # def get_mesh(self, threshold: float, color: str = "#888888", verbose: bool = True) -> go.Mesh3d:
    #     voxels_lst = [voxel for voxel in self.voxel_generator(threshold)]
    #     voxels = np.array(voxels_lst)
    #     if verbose: print("Clustering complete")
    #     xyz: np.ndarray = np.ndarray((0, 3))
    #     ijk: np.ndarray = np.ndarray((0, 3))
    #     cloud = pv.PolyData(voxels[:,1:4])
    #     surf = cloud.delaunay_3d(alpha=0.4, tol=0.001, offset=2.5)
    #     surf = surf.extract_geometry()
    #     surf = surf.triangulate()
    #     surf = surf.decimate(0.9)
    #     ijk = surf.faces.reshape(-1, 4)[:, 1:]
    #     xyz = surf.points
    #     return go.Mesh3d(
    #         x=xyz[:,0],
    #         y=xyz[:,1],
    #         z=xyz[:,2],
    #         i=ijk[:,0],
    #         j=ijk[:,1],
    #         k=ijk[:,2],
    #         opacity=0.5,
    #     )


class XYZ:
    """
    A class for representing a molecular structure in cartesian coordinates.

    The molecule is represented by a list of atoms and corresponding coordinates.

    Attributes:
        xyz (ndarray): A 2D numpy array of shape (num_atoms, 3) containing the
            coordinates of the atoms in the molecule.
        elements (list): A list of strings representing the elements of the atoms in
            the molecule.
    """
    sphere_mode = "ball"

moritz.buchhorn's avatar
moritz.buchhorn committed
    def __init__(self, coords: list[dict[str, str | float]]):
        self.xyz = np.array([[coord["x"], coord["y"], coord["z"]] for coord in coords])
        self.elements = [coord["element"] for coord in coords]
        self.meshes: list[go.Mesh3d] = []
        self.viewer: go.Figure = go.Figure()
moritz.buchhorn's avatar
moritz.buchhorn committed
        self.make_molecular_viewer()

    @classmethod
    def from_xyz_file(cls, file: str):
        """
        Reads in files with the structure 'element x y z' and returns
        a numpy array with the stored coordinates.
        @param file: String. Path to the file to be read in.
        @return: 2D numpy array. Contains the coordinates from the file at the
            corresponding indices.
        """
        with open(file) as f:
            lines = f.readlines()
        if len(lines[0].split()) == 1:
            lines = lines[2:]
moritz.buchhorn's avatar
moritz.buchhorn committed
            return cls(
                [
                    {
                        "element": line.split()[0],
                        "x": float(line.split()[1]),
                        "y": float(line.split()[2]),
                        "z": float(line.split()[3]),
                    }
                    for line in lines
                ]
            )

moritz.buchhorn's avatar
moritz.buchhorn committed
    def from_list(cls, lines: list[list[str | float]]):
        return cls(
            [
                {
                    "element": line[0],
                    "x": float(line[1]),
                    "y": float(line[2]),
                    "z": float(line[3]),
                }
                for line in lines
            ]
        )

    def get_bond_length(self, i: int, j: int) -> float:
        """
        Returns the bond length between atoms i and j.

        Args:
            i (int): Index of the first atom.
            j (int): Index of the second atom.

        Returns:
            float: The bond length between the two atoms.
        """
        return float(np.linalg.norm(self.xyz[i] - self.xyz[j]))
moritz.buchhorn's avatar
moritz.buchhorn committed

    def make_cylinder(
        self, center_i: int, center_j: int, resolution: int = 32, radius: float = 0.1
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Return cartesian coordinates of points on the edges of a cylinder.

        The cylinder is defined by the line between atoms i and j, and the radius is the given value. The number of points is determined by the resolution.

        Args:
            center_i (int): Index of the first atom that defines the cylinder.
            center_j (int): Index of the second atom that defines the cylinder.
            resolution (int, optional): The number of points to be generated. Defaults to 32.
            radius (float, optional): The radius of the cylinder. Defaults to 0.1.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: The cartesian coordinates of the points on the edges of the cylinder. The three arrays are the x, y and z coordinates of the points.
        """
        v = self.xyz[center_j] - self.xyz[center_i]
moritz.buchhorn's avatar
moritz.buchhorn committed
        x1, y1, z1 = circle_coordinates(
            self.xyz[center_i], v, radius=radius, num_points=resolution // 4
        )
        x2, y2, z2 = circle_coordinates(
            self.xyz[center_j], v, radius=radius, num_points=resolution // 4
        )
        return (
            np.concatenate((x1, x2)),
            np.concatenate((y1, y2)),
            np.concatenate((z1, z2)),
        )

    def make_atom_mesh(self, i: int, resolution: int = 32) -> go.Mesh3d:
        """
        Return a Mesh3d object that represents the atom at index i.

        The atom is represented as a sphere with a radius depending on its van der Waals radius.

        Args:
            i (int): Index of the atom to be represented.
            resolution (int, optional): The number of points to be generated. Defaults to 32.

        Returns:
            go.Mesh3d: A Mesh3d object that represents the atom.
        """
        if self.sphere_mode == "vdw":
            radius = atomp.vdw_radii_dict[self.elements[i]]
        elif self.sphere_mode == "ball":
            radius = atomp.vdw_radii_dict[self.elements[i]] * 0.2

moritz.buchhorn's avatar
moritz.buchhorn committed
        x, y, z = make_fibonacci_sphere(
            self.xyz[i], radius=radius, resolution=resolution
        )

        atom_mesh = go.Mesh3d(
moritz.buchhorn's avatar
moritz.buchhorn committed
            x=x,
            y=y,
            z=z,
            color=atomp.atom_colors_dict[self.elements[i]],
            opacity=1,
moritz.buchhorn's avatar
moritz.buchhorn committed
            name=f"{self.elements[i]}{i}",  # label is too short to also show coordinates
            hoverinfo="name",  # Only show the name on hover
        )
        return atom_mesh
moritz.buchhorn's avatar
moritz.buchhorn committed

    def make_bond_mesh(
        self, i: int, j: int, resolution: int = 32, radius: float = 0.1
    ) -> go.Mesh3d:
        """
        Return a Mesh3d object that represents a bond between atoms i and j.

        The bond is represented as a cylinder with a radius depending on the given radius.

        Args:
            i (int): Index of the first atom in the bond.
            j (int): Index of the second atom in the bond.
            resolution (int, optional): The number of points to be generated. Defaults to 32.
            radius (float, optional): The radius of the cylinder. Defaults to 0.1.

        Returns:
            go.Mesh3d: A Mesh3d object that represents the bond.
        """
        x, y, z = self.make_cylinder(i, j, resolution=resolution, radius=radius)

        bond_mesh = go.Mesh3d(
moritz.buchhorn's avatar
moritz.buchhorn committed
            x=x,
            y=y,
            z=z,
            color="#444444",
            opacity=1,
            alphahull=0,
            # name=f'{self.elements[i]}-{self.elements[j]}',
moritz.buchhorn's avatar
moritz.buchhorn committed
            hoverinfo="none",  # No hover info at all
        )
moritz.buchhorn's avatar
moritz.buchhorn committed
    def get_molecular_mesh(
        self, resolution: int = 64, rel_cutoff: float = 0.5
    ) -> go.Figure:
        """
        Returns a list of Mesh3d objects that represent the molecular structure.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): Bond are only shown if the interatomic distance is smaller than the sum of the atomic vdw radii multiplied with this parameter.
                Defaults to 0.5.

        Returns:
            go.Figure: A list of Mesh3d objects that represent the molecular structure.
        """
        mesh_list = []
        # Add the bonds
        for i in range(len(self.elements)):
moritz.buchhorn's avatar
moritz.buchhorn committed
            for j in range(i + 1, len(self.elements)):
                bond_length = self.get_bond_length(i, j)
                max_bond_length = rel_cutoff * (
                    atomp.vdw_radii_dict[self.elements[i]]
                    + atomp.vdw_radii_dict[self.elements[j]]
                )
                if bond_length < max_bond_length:
                    mesh_list.append(
                        self.make_bond_mesh(i, j, resolution=resolution, radius=0.1)
                    )

        # Add the atoms
        for i in range(len(self.elements)):
            mesh_list.append(self.make_atom_mesh(i, resolution=resolution))

        return mesh_list

moritz.buchhorn's avatar
moritz.buchhorn committed
    def make_mesh_from_cube(
        self, file: str, threshold=6 * 10**-3, verbose: bool = False
    ) -> None:
        self.viewer.add_trace(
            Volumetric.from_cube_file(file).get_mesh(
                threshold=threshold, verbose=verbose
            )
        )
moritz.buchhorn's avatar
moritz.buchhorn committed

    def make_mesh_from_gbw(
        self,
        file: str,
        index: int,
        mode: str = "mo",
        threshold=6 * 10**-3,
        dynamic_threshold: bool = False,
        verbose: bool = False,
    ) -> None:
moritz.buchhorn's avatar
moritz.buchhorn committed
        if mode == "diffdens":
            color = "tomato" if threshold < 0 else "cornflowerblue"
        if mode == "mo":
            color = "tomato" if threshold < 0 else "cornflowerblue"
moritz.buchhorn's avatar
moritz.buchhorn committed
        self.viewer.add_trace(
            Volumetric.from_gbw_file(file, index, mode).get_mesh(
                threshold=threshold,
                verbose=verbose,
                dynamic_threshold=dynamic_threshold,
moritz.buchhorn's avatar
moritz.buchhorn committed
                color=color,
moritz.buchhorn's avatar
moritz.buchhorn committed
    def make_molecular_viewer(
        self, resolution: int = 64, rel_cutoff: float = 0.5
    ) -> go.Figure:
moritz.buchhorn's avatar
moritz.buchhorn committed
        Add the molecule to the viewer. The molecular structure is represented as a collection of spheres at the atomic positions,
        with the bonds represented as cylinders between the atoms.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): Bond are only shown if the interatomic distance is smaller than the sum of the atomic vdw radii multiplied with this parameter.
                Defaults to 0.5.

        Returns:
            go.Figure: A plotly figure with the molecular structure.
        """

moritz.buchhorn's avatar
moritz.buchhorn committed
        for mesh in self.get_molecular_mesh(
            resolution=resolution, rel_cutoff=rel_cutoff
        ):
            self.viewer.add_trace(mesh)

        # Fix aspect ratio such that the molecule is displayed undistorted
moritz.buchhorn's avatar
moritz.buchhorn committed
        self.viewer.update_layout(scene_aspectmode="data")

        # Remove axes and labels
        # self.viewer.update_scenes(
        #     xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
        #     yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
        #     zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
        # )

moritz.buchhorn's avatar
moritz.buchhorn committed
        return

    def get_molecular_viewer(self) -> go.Figure:
        return self.viewer

    def reset_molecular_viewer(self) -> go.Figure:
        self.viewer = go.Figure()
        self.make_molecular_viewer()
class Trajectory:
    """
    A class to represent a molecular trajectory in terms of XYZ objects.

    Attributes:
        coordinates (list[XYZ]): A list of XYZ objects representing the atomic coordinates for each frame in the trajectory.
        vibration_vectors (list[list[float]] or None): Per-atom displacement vectors
            set by ``from_vibration_output``; None otherwise.
    """

    def __init__(self, coordinates: "list[XYZ]"):
        self.coordinates = coordinates
        self.vibration_vectors = None

    @classmethod
    def from_opt_output(cls, output_file: str):
        """
        Creates a Trajectory object from an ORCA output file.

        Args:
            output_file (str): The path to the ORCA output file.

        Returns:
            Trajectory: One frame per coordinate set returned by ``qcp.read_qc_file``.
        """
        parsed = qcp.read_qc_file(output_file)
        return cls([XYZ(frame) for frame in parsed["coordinates"]])

    @classmethod
    def from_trajectory_file(cls, trajectory_file: str):
        """
        Creates a Trajectory object from a multi-frame XYZ trajectory file.

        Args:
            trajectory_file (str): The path to the trajectory file.

        Returns:
            Trajectory: A Trajectory object.
        """
        return cls(list(cls.xyz_generator(trajectory_file)))

    @classmethod
    def from_vibration_output(cls, output_file: str, mode: int):
        """
        Creates a Trajectory object for one normal mode of an ORCA frequency calculation.

        ``orca_pltvib`` writes the animation to a temporary XYZ file. The file is
        read and then removed, also when reading it fails.

        Args:
            output_file (str): The path to the ORCA output file.
            mode (int): The mode number.

        Returns:
            Trajectory: The animated mode, with ``vibration_vectors`` set.
        """
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters reach orca_pltvib unchanged.
        sub.run(
            ["orca_pltvib", output_file, str(mode)],
            text=True,
            capture_output=True,
        )
        trajectory_file = f"{output_file}.v{mode:03d}.xyz"
        try:
            trajectory = cls.from_trajectory_file(trajectory_file)
            trajectory.vibration_vectors = cls.get_vibration_vectors(trajectory_file)
        finally:
            if os.path.exists(trajectory_file):
                os.remove(trajectory_file)
        return trajectory

    @staticmethod
    def xyz_generator(trajectory_file: str):
        """
        Generator that yields one XYZ object per frame of a trajectory file.

        Each frame is an atom count line, a comment line and the atom lines. Trailing
        blank lines are ignored, so they do not produce an empty extra frame.

        Args:
            trajectory_file (str): The path to the trajectory file.

        Yields:
            XYZ: An XYZ object.
        """
        with open(trajectory_file) as f:
            lines = f.readlines()
        while lines and not lines[-1].strip():
            lines.pop()
        if not lines:
            return
        block_size = int(lines[0]) + 2
        for start in range(0, len(lines), block_size):
            yield XYZ.from_list(
                [line.split() for line in lines[start + 2 : start + block_size]]
            )

    @staticmethod
    def get_vibration_vectors(trajectory_file: str):
        """
        Returns the vibration vectors of the first frame of a trajectory file.

        The file must be generated by orca_pltvib, so that the vectors are in the
        last three columns.

        Args:
            trajectory_file (str): The path to the trajectory file.

        Returns:
            list[list[float]]: One [dx, dy, dz] vector per atom.
        """
        with open(trajectory_file) as f:
            lines = f.readlines()
        natoms = int(lines[0])
        return [
            [float(value) for value in line.split()[-3:]]
            for line in lines[2 : 2 + natoms]
        ]

    def get_vibration_vectors_mesh(self) -> "list[go.Mesh3d]":
        """
        Returns one arrow mesh per atom, placed at the atom positions of the first frame.

        Raises:
            ValueError: If no vibration vectors are available.
        """
        if self.vibration_vectors is None:
            raise ValueError("Vibration vectors must be generated first")
        first_frame_coords = self.coordinates[0].xyz
        return [
            make_arrow_mesh(at=position, dir=vector)
            for position, vector in zip(first_frame_coords, self.vibration_vectors)
        ]

    def get_molecular_viewer_animated(
        self, resolution: int = 64, rel_cutoff: float = 0.5, arrows: bool = False
    ) -> "go.Figure":
        """
        Returns a plotly figure with the molecular structure of the trajectory. The figure comes with an animation.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): A bond is drawn only if the interatomic
                distance is smaller than the sum of the two vdW radii times this
                factor. Defaults to 0.5.
            arrows (bool, optional): Add the vibration vectors as arrows, if available.

        Returns:
            go.Figure: A plotly figure with the molecular structure. It shows the
                first frame until "Play" is pressed.
        """
        # The arrows depend only on the first frame, so build them once.
        arrow_meshes = (
            self.get_vibration_vectors_mesh()
            if arrows and self.vibration_vectors is not None
            else []
        )
        frame_data = [
            xyz.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff)
            + arrow_meshes
            for xyz in self.coordinates
        ]

        # initial frame before animation
        fig = go.Figure(
            data=frame_data[0] if frame_data else None,
            frames=[go.Frame(data=data) for data in frame_data],
        )
        fig.update_layout(
            updatemenus=[
                {
                    "type": "buttons",
                    "buttons": [
                        {
                            "label": "Play",
                            "method": "animate",
                            "args": [None, {"frame": {"duration": 200}}],
                        }
                    ],
                }
            ],
            scene_aspectmode="data",
        )
        return fig

    def get_molecular_viewer_slider(
        self, resolution: int = 64, rel_cutoff: float = 0.5, arrows: bool = False
    ) -> "go.Figure":
        """
        Returns a plotly figure with the molecular structure of the trajectory. The figure comes with a slider that allows to switch between frames.

        All frames are added as traces. Each slider step makes only one frame's
        traces visible, and initially only the first frame is shown.

        Args:
            resolution (int, optional): The number of points to be generated. Defaults to 64.
            rel_cutoff (float, optional): A bond is drawn only if the interatomic
                distance is smaller than the sum of the two vdW radii times this
                factor. Defaults to 0.5.
            arrows (bool, optional): Add the vibration vectors as arrows, if available.

        Returns:
            go.Figure: A plotly figure with the molecular structure.
        """
        fig = go.Figure()
        arrow_meshes = (
            self.get_vibration_vectors_mesh()
            if arrows and self.vibration_vectors is not None
            else []
        )
        traces_per_frame = []
        for xyz in self.coordinates:
            meshes = (
                xyz.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff)
                + arrow_meshes
            )
            traces_per_frame.append(len(meshes))
            for mesh in meshes:
                fig.add_trace(mesh)

        # bounds[f]:bounds[f + 1] is the trace range belonging to frame f
        bounds = list(accumulate(traces_per_frame, initial=0))
        n_traces = bounds[-1]
        slider_steps: list[dict] = []
        for frame, (start, stop) in enumerate(zip(bounds, bounds[1:])):
            visible = [False] * n_traces
            visible[start:stop] = [True] * (stop - start)
            slider_steps.append(
                {
                    "method": "restyle",
                    "args": [{"visible": visible}],
                    "label": str(frame),
                }
            )

        # Without this, every frame is drawn on top of each other until the slider moves.
        if slider_steps:
            for trace, shown in zip(fig.data, slider_steps[0]["args"][0]["visible"]):
                trace.visible = shown

        slider = {
            "active": 0,
            "currentvalue": {"prefix": "Frame: "},
            "steps": slider_steps,
        }
        fig.update_layout(sliders=[slider], scene_aspectmode="data")
        return fig
moritz.buchhorn's avatar
moritz.buchhorn committed


class Spectrum:
    """
    Energies in eV.
    """

    units: dict[str, dict[str, str | float | bool]] = {
        "ev": {"xaxis_title": "Energy / eV", "factor": 1.0, "inverse": False},
        "cm-1": {
            "xaxis_title": "Wavenumber / cm^-1",
            "factor": 8065.610420,
            "inverse": False,
        },
        "nm": {"xaxis_title": "Wavelength / nm", "factor": 1239.84193, "inverse": True},
    }

    def __init__(
        self, indices: list[int], energies: list[float], intensities: list[float]
    ):
        self.indices = indices
        self.energies = energies
        self.intensities = intensities
        return

    @classmethod
    def from_tddft_output(cls, output_file: str):
        indices = []
        energies = []
        foscs = []
        with open(output_file) as f:
            for line in f:
                if "ABSORPTION SPECTRUM VIA TRANSITION ELECTRIC DIPOLE MOMENTS" in line:
                    next(f)
                    next(f)
                    next(f)
                    next(f)
                    while True:
                        line = next(f)
                        if len(line.split()) == 0:
                            break
                        line = line.split()
                        fosc = float(line[3])
                        if fosc < 0.01:
                            continue
                        indices.append(int(line[0]))
                        energies.append(float(line[1]) / 8065.610420)
                        foscs.append(fosc)
        return cls(indices, energies, foscs)

    @staticmethod
    def get_impulse_line(
        x: list[float], y: list[float]
    ) -> tuple[list[float | None], list[float | None]]:
        new_x: list[float | None] = []
        new_y: list[float | None] = []
        for xx, yy in zip(x, y):
            new_x.append(xx)
            new_y.append(0.0)
            new_x.append(xx)
            new_y.append(yy)
            new_x.append(None)
            new_y.append(None)
        return new_x, new_y

    def get_energies_as(self, unit: str) -> list[float]:
        unit = unit.lower()
        assert unit in ["ev", "cm-1", "nm"], "Unit must be 'ev', 'cm-1' or 'nm'"
        if self.units[unit]["inverse"]:
            x = [self.units[unit]["factor"] / e for e in self.energies]
        else:
            x = [self.units[unit]["factor"] * e for e in self.energies]
        return x

    def get_line_spectrum(
        self, unit: str
    ) -> tuple[list[float | None], list[float | None]]:
        return Spectrum.get_impulse_line(self.get_energies_as(unit), self.intensities)