from __future__ import annotations
from itertools import chain
import numpy as np
from pymatgen.analysis.graphs import ConnectedSite
from pymatgen.core.periodic_table import DummySpecie
from pymatgen.core.sites import Site
from pymatgen.electronic_structure.core import Magmom
from scipy.spatial.qhull import Delaunay
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.scene import (
Arrows,
Convex,
Cubes,
Cylinders,
Scene,
Spheres,
Surface,
)
def _site_tooltip(site, sp, occu: float, position: list[float]) -> str:
    """Build the hover tooltip for one species drawn on ``site``.

    Args:
        site: The site being drawn; only its ``properties`` mapping is read.
        sp: The species; its ``str()`` starts the tooltip.
        occu: Fractional occupancy of ``sp`` on the site (e.g. 0.5 for half).
        position: Cartesian coordinates ``[x, y, z]`` of the site.

    Returns:
        Text such as ``"Fe (50% occupancy) (0.000, 1.000, 2.000) (magmom = 2)"``.
    """
    name = str(sp)
    if occu != 1.0:
        # Occupancies are fractions of the whole site, so scale to a percentage.
        name += f" ({occu * 100:g}% occupancy)"
    name += f" ({position[0]:.3f}, {position[1]:.3f}, {position[2]:.3f})"
    for key, value in (site.properties or {}).items():
        name += f" ({key} = {value})"
    return name


def _half_bond_end(start, end, length_scale: float = 1.0) -> list[float]:
    """Return the end point of a half-bond drawn from ``start`` towards ``end``.

    The half-bond spans ``length_scale`` times half of the ``start``->``end``
    distance, measured from ``start``; ``length_scale=1`` gives the midpoint.
    """
    start_arr = np.asarray(start, dtype=float)
    end_arr = np.asarray(end, dtype=float)
    return (start_arr + length_scale * (end_arr - start_arr) / 2).tolist()


def get_site_scene(
    self,
    connected_sites: list[ConnectedSite] | None = None,
    connected_sites_not_drawn: list[ConnectedSite] | None = None,
    hide_incomplete_edges: bool = False,
    incomplete_edge_length_scale: float | None = 1.0,
    connected_sites_colors: list[str] | None = None,
    connected_sites_not_drawn_colors: list[str] | None = None,
    origin: list[float] | None = None,
    draw_polyhedra: bool = True,
    explicitly_calculate_polyhedra_hull: bool = False,
    bond_radius: float = 0.1,
    draw_magmoms: bool = True,
    magmom_scale: float = 1.0,
    legend: Legend | None = None,
) -> Scene:
    """Build a scene for this site: its atom(s), half-bonds, polyhedron and magmom.

    Args:
        connected_sites: Neighbours to draw half-bonds to; their positions are
            also the vertices of the coordination polyhedron.
        connected_sites_not_drawn: Neighbours that are not themselves drawn. A
            scaled half-bond is drawn towards each unless
            ``hide_incomplete_edges`` is set. If any are given, no polyhedron
            is drawn.
        hide_incomplete_edges: If True, draw no bonds towards
            ``connected_sites_not_drawn``.
        incomplete_edge_length_scale: Length of each incomplete half-bond as a
            multiple of half the bond length, measured from this site
            (``None`` is treated as 1.0).
        connected_sites_colors: Bond colours parallel to ``connected_sites``;
            defaults to the site colour.
        connected_sites_not_drawn_colors: Bond colours parallel to
            ``connected_sites_not_drawn``; defaults to the site colour.
        origin: Passed through to the returned ``Scene``.
        draw_polyhedra: Whether a coordination polyhedron may be drawn.
        explicitly_calculate_polyhedra_hull: If True, triangulate the hull here
            with ``Delaunay`` and emit a ``Surface``; otherwise emit a
            ``Convex`` primitive.
        bond_radius: Radius of the bond cylinders.
        draw_magmoms: Whether to draw the ``"magmom"`` site property as an arrow.
        magmom_scale: Extra factor applied to the magmom arrow length.
        legend: Colour/radius lookup; ``Legend(self)`` is used if omitted.

    Returns:
        A ``Scene`` named after ``self.species_string`` with the child scenes
        ``"atoms"``, ``"bonds"``, ``"polyhedra"`` and ``"magmoms"``.
    """
    legend = legend or Legend(self)

    atoms = []
    bonds = []
    polyhedron = []
    magmoms = []

    is_ordered = self.is_ordered
    position = self.coords.tolist()

    radii = [legend.get_radius(sp, site=self) for sp in self.species]
    # NOTE(review): despite its name this is the *smallest* species radius. It
    # sizes the vacancy cap and the magmom arrow below; confirm whether max()
    # was intended.
    max_radius = float(min(radii))

    # Disordered sites are drawn as a spherical "pie chart": each species gets
    # the azimuthal segment [phi_start, phi_end] proportional to its occupancy.
    phi_start, phi_end = None, None
    occu_start = 0.0
    color = None  # colour of the last real (non-dummy) species drawn

    for sp, occu in self.species.items():
        if isinstance(sp, DummySpecie):
            atoms.append(
                Cubes(
                    positions=[position],
                    color=legend.get_color(sp, site=self),
                    width=0.4,
                )
            )
            continue

        color = legend.get_color(sp, site=self)
        radius = legend.get_radius(sp, site=self)

        # TODO: make optional/default to None
        if not is_ordered:
            phi_frac_start = occu_start
            phi_frac_end = occu_start + occu
            occu_start = phi_frac_end
            phi_start = phi_frac_start * np.pi * 2
            phi_end = phi_frac_end * np.pi * 2

        atoms.append(
            Spheres(
                positions=[position],
                color=color,
                radius=radius,
                phiStart=phi_start,
                phiEnd=phi_end,
                clickable=True,
                tooltip=_site_tooltip(self, sp, occu, position),
            )
        )

    # If the occupancies don't fill the full circle, cap the rest in white.
    # phi_end stays None when only dummy species were present.
    if not is_ordered and phi_end is not None and not np.isclose(phi_end, np.pi * 2):
        atoms.append(
            Spheres(
                positions=[position],
                color="#ffffff",
                radius=max_radius,
                phiStart=phi_end,
                phiEnd=np.pi * 2,
            )
        )

    # The magnetic moment is a property of the site, so draw one arrow for it
    # rather than one per species.
    if draw_magmoms:
        magmom = self.properties.get("magmom")
        if magmom is not None:
            moment = np.array(Magmom(magmom).get_moment())
            # A zero moment has nothing to draw. Testing with np.any also avoids
            # the ambiguous truth value of array-valued magmoms.
            if np.any(moment):
                moment = 2 * magmom_scale * max_radius * moment
                centre = np.array(position)
                magmoms.append(
                    Arrows(
                        positionPairs=[[centre - 0.5 * moment, centre + 0.5 * moment]],
                        color="red",
                        radius=0.20,
                        headLength=0.5,
                        headWidth=0.4,
                        clickable=True,
                    )
                )

    if connected_sites:
        # A disordered site has no single colour, so re-use the last species
        # colour. Fall back to the first species when every species is a dummy.
        site_color = (
            color
            if color is not None
            else legend.get_color(next(iter(self.species)), site=self)
        )

        # Include the centre in case it lies outside its neighbours' hull.
        # TODO: can cause a bug if all vertices almost co-planar
        all_positions = [position]

        for idx, connected_site in enumerate(connected_sites):
            connected_position = connected_site.site.coords
            bonds.append(
                Cylinders(
                    positionPairs=[
                        [position, _half_bond_end(position, connected_position)]
                    ],
                    color=(
                        connected_sites_colors[idx]
                        if connected_sites_colors
                        else site_color
                    ),
                    radius=bond_radius,
                )
            )
            all_positions.append(connected_position.tolist())

        if connected_sites_not_drawn and not hide_incomplete_edges:
            length_scale = (
                1.0
                if incomplete_edge_length_scale is None
                else incomplete_edge_length_scale
            )
            for idx, connected_site in enumerate(connected_sites_not_drawn):
                connected_position = connected_site.site.coords
                # Scale the half-bond relative to this site, not the absolute
                # midpoint coordinates.
                bond_end = _half_bond_end(position, connected_position, length_scale)
                bonds.append(
                    Cylinders(
                        positionPairs=[[position, bond_end]],
                        color=(
                            connected_sites_not_drawn_colors[idx]
                            if connected_sites_not_drawn_colors
                            else site_color
                        ),
                        radius=bond_radius,
                    )
                )
                all_positions.append(connected_position.tolist())

        # Avoid intersecting polyhedra: only draw one when every neighbour's
        # species sorts strictly above this site's species in pymatgen's
        # (electronegativity-based) species ordering.
        if (
            draw_polyhedra
            and len(connected_sites) > 3
            and not connected_sites_not_drawn
            and not any(
                (cs.site.specie < self.specie) or (cs.site.specie == self.specie)
                for cs in connected_sites
            )
        ):
            if explicitly_calculate_polyhedra_hull:
                try:
                    # .convex_hull is an (n_facets, 3) array of indices into
                    # all_positions, one triangle per facet.
                    vertices_indices = Delaunay(all_positions).convex_hull
                except Exception:
                    # Best effort: Qhull rejects degenerate point sets; draw no
                    # polyhedron rather than failing the whole scene.
                    vertices_indices = []
                vertices = [
                    all_positions[i] for i in chain.from_iterable(vertices_indices)
                ]
                if vertices:
                    polyhedron = [Surface(positions=vertices, color=site_color)]
            else:
                polyhedron = [Convex(positions=all_positions, color=site_color)]

    return Scene(
        self.species_string,
        [
            Scene("atoms", contents=atoms),
            Scene("bonds", contents=bonds),
            Scene("polyhedra", contents=polyhedron),
            Scene("magmoms", contents=magmoms),
        ],
        origin=origin,
    )
# Monkey-patch pymatgen's Site so every site instance gains a get_scene() method.
setattr(Site, "get_scene", get_site_scene)