Skip to content
Snippets Groups Projects
Commit 777d8f9d authored by moritz.buchhorn's avatar moritz.buchhorn
Browse files

Added a working trajectory viewer on the first try!

parent 00a64da8
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import qc_support as qcs
import qc_parser as qcp
import support
```
%% Cell type:code id: tags:
``` python
support.XYZ.from_file("hcho-guess.xyz").get_molecular_viewer().show()
```
%% Cell type:code id: tags:
``` python
parsed = qcp.read_qc_file("./hcho-opt.out")
```
%% Cell type:code id: tags:
``` python
parsed["coordinates"]
coord_list = [support.XYZ(frame) for frame in parsed["coordinates"]]
trajectory = support.Trajectory(coord_list)
trajectory.get_molecular_viewer()
```
%% Cell type:code id: tags:
``` python
```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 27 11:51:18 2024
@author: benjaminlear
Adapted by M.Buchhorn, thank you Ben!
"""
atom_colors = [
"#000000", # 0 Unknown
"#dddddd", # 1 H
"cyan", # 2 He
"violet", # 3 Li
"#B22222", # 4 Be
"beige", # 5 B
"#444444", # 6 C
"blue", # 7 N
"red", # 8 O
"green", # 9 F
"cyan", # 10 Ne
"#FF1493", # 11 Na
"#A52A2A", # 12 Mg
"#800080", # 13 Al
"#D2691E", # 14 Si
"#228B22", # 15 P
"yellow", # 16 S
"green", # 17 Cl
"cyan", # 18 Ar
"#0000CD", # 19 K
"#8B4513", # 20 Ca
"#8FBC8F", # 21 Sc
"#B8860B", # 22 Ti
"#4682B4", # 23 V
"#B22222", # 24 Cr
"#2E8B57", # 25 Mn
"#FFD700", # 26 Fe
"#DAA520", # 27 Co
"#A52A2A", # 28 Ni
"#4169E1", # 29 Cu
"#708090", # 30 Zn
"#C0C0C0", # 31 Ga
"#808000", # 32 Ge
"#00FF00", # 33 As
"#00CED1", # 34 Se
"#0000FF", # 35 Br
"cyan", # 36 Kr
"#8A2BE2", # 37 Rb
"#8B4513", # 38 Sr
"#9400D3", # 39 Y
"#FF4500", # 40 Zr
"#4682B4", # 41 Nb
"#B22222", # 42 Mo
"#FFD700", # 43 Tc
"#A52A2A", # 44 Ru
"#228B22", # 45 Rh
"#8B4513", # 46 Pd
"#00CED1", # 47 Ag
"#696969", # 48 Cd
"#C0C0C0", # 49 In
"#808000", # 50 Sn
"#228B22", # 51 Sb
"#D2691E", # 52 Te
"#00FF00", # 53 I
"cyan", # 54 Xe
"#8A2BE2", # 55 Cs
"#8B4513", # 56 Ba
"#FF4500", # 57 La
"#FF1493", # 58 Ce
"#9400D3", # 59 Pr
"#4682B4", # 60 Nd
"#DAA520", # 61 Pm
"#B22222", # 62 Sm
"#228B22", # 63 Eu
"#A52A2A", # 64 Gd
"#B8860B", # 65 Tb
"#8FBC8F", # 66 Dy
"#4682B4", # 67 Ho
"#DAA520", # 68 Er
"#B22222", # 69 Tm
"#228B22", # 70 Yb
"#FF4500", # 71 Lu
"#A52A2A", # 72 Hf
"#4682B4", # 73 Ta
"#DAA520", # 74 W
]
atom_symbols = [
"unk", # 0
"H", # 1
"He", # 2
"Li", # 3
"Be", # 4
"B", # 5
"C", # 6
"N", # 7
"O", # 8
"F", # 9
"Ne", # 10
"Na", # 11
"Mg", # 12
"Al", # 13
"Si", # 14
"P", # 15
"S", # 16
"Cl", # 17
"Ar", # 18
"K", # 19
"Ca", # 20
"Sc", # 21
"Ti", # 22
"V", # 23
"Cr", # 24
"Mn", # 25
"Fe", # 26
"Co", # 27
"Ni", # 28
"Cu", # 29
"Zn", # 30
"Ga", # 31
"Ge", # 32
"As", # 33
"Se", # 34
"Br", # 35
"Kr", # 36
"Rb", # 37
"Sr", # 38
"Y", # 39
"Zr", # 40
"Nb", # 41
"Mo", # 42
"Tc", # 43
"Ru", # 44
"Rh", # 45
"Pd", # 46
"Ag", # 47
"Cd", # 48
"In", # 49
"Sn", # 50
"Sb", # 51
"Te", # 52
"I", # 53
"Xe", # 54
"Cs", # 55
"Ba", # 56
"La", # 57
"Ce", # 58
"Pr", # 59
"Nd", # 60
"Pm", # 61
"Sm", # 62
"Eu", # 63
"Gd", # 64
"Tb", # 65
"Dy", # 66
"Ho", # 67
"Er", # 68
"Tm", # 69
"Yb", # 70
"Lu", # 71
"Hf", # 72
"Ta", # 73
"W", # 74
"Re", # 75
"Os", # 76
"Ir", # 77
"Pt", # 78
"Au", # 79
"Hg", # 80
"Tl", # 81
"Pb", # 82
"Bi", # 83
"Po", # 84
"At", # 85
"Rn", # 86
"Fr", # 87
"Ra", # 88
"Ac", # 89
"Th", # 90
"Pa", # 91
"U", # 92
"Np", # 93
"Pu", # 94
"Am", # 95
"Cm", # 96
"Bk", # 97
"Cf", # 98
"Es", # 99
"Fm", # 100
"Md", # 101
"No", # 102
"Lr", # 103
"Rf", # 104
"Db", # 105
"Sg", # 106
"Bh", # 107
"Hs", # 108
"Mt", # 109
"Ds", # 110
"Rg", # 111
"Cn", # 112
"Nh", # 113
"Fl", # 114
"Mc", # 115
"Lv", # 116
"Ts", # 117
"Og", # 118
]
vdw_radii = [
1.70, # 0 Unknown, same as Carbon
1.20, # 1 H
1.40, # 2 He
1.82, # 3 Li
1.53, # 4 Be
1.92, # 5 B
1.70, # 6 C
1.55, # 7 N
1.52, # 8 O
1.47, # 9 F
1.54, # 10 Ne
2.27, # 11 Na
1.73, # 12 Mg
1.84, # 13 Al
2.10, # 14 Si
1.80, # 15 P
1.80, # 16 S
1.75, # 17 Cl
1.88, # 18 Ar
2.75, # 19 K
2.31, # 20 Ca
2.11, # 21 Sc
2.00, # 22 Ti
2.00, # 23 V
2.00, # 24 Cr
2.00, # 25 Mn
2.00, # 26 Fe
2.00, # 27 Co
1.63, # 28 Ni
1.40, # 29 Cu
1.39, # 30 Zn
1.87, # 31 Ga
2.11, # 32 Ge
1.85, # 33 As
1.90, # 34 Se
1.85, # 35 Br
2.02, # 36 Kr
3.03, # 37 Rb
2.49, # 38 Sr
2.19, # 39 Y
2.06, # 40 Zr
2.00, # 41 Nb
2.00, # 42 Mo
2.00, # 43 Tc
2.00, # 44 Ru
2.00, # 45 Rh
2.00, # 46 Pd
1.72, # 47 Ag
1.58, # 48 Cd
1.93, # 49 In
2.17, # 50 Sn
2.00, # 51 Sb
2.06, # 52 Te
1.98, # 53 I
2.16, # 54 Xe
3.43, # 55 Cs
2.68, # 56 Ba
2.50, # 57 La
2.48, # 58 Ce
2.47, # 59 Pr
2.45, # 60 Nd
2.43, # 61 Pm
2.42, # 62 Sm
2.40, # 63 Eu
2.38, # 64 Gd
2.37, # 65 Tb
2.35, # 66 Dy
2.33, # 67 Ho
2.32, # 68 Er
2.30, # 69 Tm
2.28, # 70 Yb
2.27, # 71 Lu
2.16, # 72 Hf
2.09, # 73 Ta
2.02, # 74 W
]
vdw_radii_dict = dict(zip(atom_symbols, vdw_radii))
atom_colors_dict = dict(zip(atom_symbols, atom_colors))
\ No newline at end of file
import numpy as np
import plotly.graph_objects as go
import atom_properties as atomp
from itertools import accumulate
class XYZ:
sphere_mode = "ball"
def __init__(self, coords: list[dict[str, str|float]]):
self.xyz = np.array([[coord["x"], coord["y"], coord["z"]] for coord in coords])
self.elements = [coord["element"] for coord in coords]
@classmethod
def from_file(cls, file: str):
"""
Reads in files with the structure 'element x y z' and returns
a numpy array with the stored coordinates.
@param file: String. Path to the file to be read in.
@return: 2D numpy array. Contains the coordinates from the file at the
corresponding indices.
"""
with open(file) as f:
lines = f.readlines()
if len(lines[0].split()) == 1:
lines = lines[2:]
return cls([{"element": line.split()[0],
"x": float(line.split()[1]),
"y": float(line.split()[2]),
"z": float(line.split()[3])} for line in lines])
def get_bond_length(self, i: int, j: int) -> float:
"""
Returns the bond length between atoms i and j.
Args:
i (int): Index of the first atom.
j (int): Index of the second atom.
Returns:
float: The bond length between the two atoms.
"""
return float(np.linalg.norm(self.xyz[i] - self.xyz[j]))
def circle_coordinates(self, point: np.ndarray, v: np.ndarray, radius: float, num_points: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Generate cartesian coordinates of a circle in 3D space.
Args:
point (np.ndarray): A point on the circle.
v (np.ndarray): A vector perpendicular to the circle.
radius (float): The radius of the circle.
num_points (int): The number of points to generate on the circle.
Returns:
np.ndarray: A 2D array of shape (num_points, 3) containing the cartesian coordinates of the circle.
"""
# Normalize the vector v
v = v / np.linalg.norm(v)
# Create a vector u that is perpendicular to v
u = np.cross(v, [1, 0, 0]) if np.linalg.norm(np.cross(v, [1, 0, 0])) > 1e-6 else np.cross(v, [0, 1, 0])
u = u / np.linalg.norm(u)
# Create a vector w that is perpendicular to both u and v
w = np.cross(u, v)
# Generate the cartesian coordinates of the circle
theta = np.linspace(0, 2 * np.pi, num_points)
circle_coords = point + radius * (u * np.cos(theta)[:, np.newaxis] + w * np.sin(theta)[:, np.newaxis])
x, y, z = circle_coords[:, 0], circle_coords[:, 1], circle_coords[:, 2]
return x, y, z
def make_fibonacci_sphere(self, center: np.ndarray, radius: float = 0.1, resolution: int = 32) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Return cartesian coordinates of points evenly distributed on the surface of a sphere.
Args:
center (np.ndarray): The coordinates of the center of the sphere.
radius (float, optional): The radius of the sphere. Defaults to 0.1.
resolution (int, optional): The number of points to be generated. Defaults to 32.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: The cartesian coordinates of the points
on the surface of the sphere. The three arrays are the x, y and z coordinates
of the points.
"""
num_points = resolution
indices = np.arange(0, num_points, dtype=float) + 0.5
phi = np.arccos(1 - 2*indices/num_points)
theta = np.pi * (1 + 5**0.5) * indices
x = radius * np.sin(phi) * np.cos(theta) + center[0]
y = radius * np.sin(phi) * np.sin(theta) + center[1]
z = radius * np.cos(phi) + center[2]
return x, y, z
def make_cylinder(self, center_i: int, center_j: int, resolution: int = 32, radius: float = 0.1) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Return cartesian coordinates of points on the edges of a cylinder.
The cylinder is defined by the line between atoms i and j, and the radius is the given value. The number of points is determined by the resolution.
Args:
center_i (int): Index of the first atom that defines the cylinder.
center_j (int): Index of the second atom that defines the cylinder.
resolution (int, optional): The number of points to be generated. Defaults to 32.
radius (float, optional): The radius of the cylinder. Defaults to 0.1.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: The cartesian coordinates of the points on the edges of the cylinder. The three arrays are the x, y and z coordinates of the points.
"""
v = self.xyz[center_j] - self.xyz[center_i]
x1, y1, z1 = self.circle_coordinates(self.xyz[center_i], v, radius=radius, num_points=resolution//4)
x2, y2, z2 = self.circle_coordinates(self.xyz[center_j], v, radius=radius, num_points=resolution//4)
return np.concatenate((x1, x2)), np.concatenate((y1, y2)), np.concatenate((z1, z2))
def make_atom_mesh(self, i: int, resolution: int = 32) -> go.Mesh3d:
"""
Return a Mesh3d object that represents the atom at index i.
The atom is represented as a sphere with a radius depending on its van der Waals radius.
Args:
i (int): Index of the atom to be represented.
resolution (int, optional): The number of points to be generated. Defaults to 32.
Returns:
go.Mesh3d: A Mesh3d object that represents the atom.
"""
if self.sphere_mode == "vdw":
radius = atomp.vdw_radii_dict[self.elements[i]]
elif self.sphere_mode == "ball":
radius = atomp.vdw_radii_dict[self.elements[i]] * 0.2
x, y, z = self.make_fibonacci_sphere(self.xyz[i], radius=radius, resolution=resolution)
atom_mesh = go.Mesh3d(
x=x,
y=y,
z=z,
color=atomp.atom_colors_dict[self.elements[i]],
opacity=1,
alphahull=0,
name=f'{self.elements[i]}{i}', # label is too short to also show coordinates
hoverinfo='name', # Only show the name on hover
)
return atom_mesh
def make_bond_mesh(self, i: int, j: int, resolution: int = 32, radius: float = 0.1) -> go.Mesh3d:
"""
Return a Mesh3d object that represents a bond between atoms i and j.
The bond is represented as a cylinder with a radius depending on the given radius.
Args:
i (int): Index of the first atom in the bond.
j (int): Index of the second atom in the bond.
resolution (int, optional): The number of points to be generated. Defaults to 32.
radius (float, optional): The radius of the cylinder. Defaults to 0.1.
Returns:
go.Mesh3d: A Mesh3d object that represents the bond.
"""
x, y, z = self.make_cylinder(i, j, resolution=resolution, radius=radius)
bond_mesh = go.Mesh3d(
x=x,
y=y,
z=z,
color="#444444",
opacity=1,
alphahull=0,
# name=f'{self.elements[i]}-{self.elements[j]}',
hoverinfo='none', # No hover info at all
)
return bond_mesh
def get_molecular_mesh(self, resolution: int = 64, rel_cutoff: float = 0.5) -> go.Figure:
"""
Returns a list of Mesh3d objects that represent the molecular structure.
Args:
resolution (int, optional): The number of points to be generated. Defaults to 64.
rel_cutoff (float, optional): Bond are only shown if the interatomic distance is smaller than the sum of the atomic vdw radii multiplied with this parameter.
Defaults to 0.5.
Returns:
go.Figure: A list of Mesh3d objects that represent the molecular structure.
"""
mesh_list = []
# Add the bonds
for i in range(len(self.elements)):
for j in range(i+1, len(self.elements)):
bond_length = self.get_bond_length(i, j)
max_bond_length = rel_cutoff * (atomp.vdw_radii_dict[self.elements[i]] + atomp.vdw_radii_dict[self.elements[j]])
if bond_length < max_bond_length:
mesh_list.append(self.make_bond_mesh(i, j, resolution=resolution, radius=0.1))
# Add the atoms
for i in range(len(self.elements)):
mesh_list.append(self.make_atom_mesh(i, resolution=resolution))
return mesh_list
def get_molecular_viewer(self, resolution: int = 64, rel_cutoff: float = 0.5) -> go.Figure:
"""
Returns a plotly figure with the molecular structure.
The molecular structure is represented as a collection of spheres at the atomic positions,
with the bonds represented as cylinders between the atoms.
Args:
resolution (int, optional): The number of points to be generated. Defaults to 64.
rel_cutoff (float, optional): Bond are only shown if the interatomic distance is smaller than the sum of the atomic vdw radii multiplied with this parameter.
Defaults to 0.5.
Returns:
go.Figure: A plotly figure with the molecular structure.
"""
fig = go.Figure()
for mesh in self.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff):
fig.add_trace(mesh)
# Fix aspect ratio such that the molecule is displayed undistorted
fig.update_layout(scene_aspectmode='data')
# Remove axes and labels
# fig.update_scenes(
# xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
# yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
# zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""),
# )
return fig
class Trajectory:
def __init__(self, coordinates: list[XYZ]):
self.coordinates = coordinates
def get_molecular_viewer(self, resolution: int = 64, rel_cutoff: float = 0.5) -> go.Figure:
"""
Returns a plotly figure with the molecular structure of the trajectory.
Args:
resolution (int, optional): The number of points to be generated. Defaults to 64.
rel_cutoff (float, optional): Bond are only shown if the interatomic distance is smaller than the sum of the atomic vdw radii multiplied with this parameter.
Defaults to 0.5.
Returns:
go.Figure: A plotly figure with the molecular structure.
"""
# Get and add all the meshes
mesh_elements = [0]
fig = go.Figure()
for xyz in self.coordinates:
meshes = xyz.get_molecular_mesh(resolution=resolution, rel_cutoff=rel_cutoff)
mesh_elements.append(len(meshes))
for mesh in meshes:
fig.add_trace(mesh)
mesh_elements_cumulative = list(accumulate(mesh_elements))
# Make the visibility list
slider_steps: list[dict] = []
for i in range(len(self.coordinates)):
step = {
"method": "restyle",
"args": [{"visible": [False] * mesh_elements_cumulative[-1]}]
}
step["args"][0]["visible"][mesh_elements_cumulative[i]:mesh_elements_cumulative[i+1]] = [True] * mesh_elements[i]
slider_steps.append(step)
# Make the slider
slider = {
"active": 0,
"currentvalue": {"prefix": "Frame: "},
"steps": slider_steps
}
fig.update_layout(
sliders=[slider]
)
return fig
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment