from __future__ import annotations
from collections import defaultdict
from itertools import combinations
import numpy as np
from matplotlib.cm import get_cmap
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.core.sites import PeriodicSite
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.scene import Scene
def _get_sites_to_draw(
    self, draw_image_atoms=True, bonded_sites_outside_unit_cell=False
):
    """Collect the (site index, image vector) pairs that should be rendered.

    Every site of ``self.structure`` is always included with the image
    vector ``(0, 0, 0)``.

    Args:
        draw_image_atoms: If True, also include periodic images of sites
            lying on the cell boundary. A fractional coordinate within 0.05
            of 0 may be shifted by +1 along that axis, and one within 0.05
            of 1 may be shifted by -1. Every non-empty combination of the
            shiftable axes is emitted, including mixed +1/-1 combinations
            (e.g. a site at ``(0, 1, 0.5)`` yields ``(1, -1, 0)`` as well as
            ``(1, 0, 0)`` and ``(0, -1, 0)``).
        bonded_sites_outside_unit_cell: If True, call
            ``self.get_connected_sites`` for every pair collected so far and
            add each neighbour whose ``jimage`` is not ``(0, 0, 0)``. Only
            one hop is followed: the neighbours added here are not
            themselves expanded.

    Returns:
        A set of ``(index, (i, j, k))`` tuples with integer image vectors.
    """
    sites_to_draw = [(idx, (0, 0, 0)) for idx in range(len(self.structure))]

    if draw_image_atoms:
        for idx, site in enumerate(self.structure):
            # Map each boundary axis to the lattice shift that produces its
            # image on the opposite face of the cell.
            axis_shift = {}
            for axis, frac in enumerate(site.frac_coords):
                if np.allclose(frac, 0, atol=0.05):
                    axis_shift[axis] = 1
                elif np.allclose(frac, 1, atol=0.05):
                    axis_shift[axis] = -1

            # Enumerate every non-empty subset of boundary axes so that
            # faces, edges and corners are all covered in one pass.
            for length in range(1, len(axis_shift) + 1):
                for axes in combinations(axis_shift, length):
                    image = tuple(
                        axis_shift[axis] if axis in axes else 0
                        for axis in range(3)
                    )
                    sites_to_draw.append((idx, image))

    if bonded_sites_outside_unit_cell:
        # Gather into a separate list so the list being iterated is not
        # mutated while we walk it.
        sites_to_append = []
        for n, jimage in sites_to_draw:
            for connected_site in self.get_connected_sites(n, jimage=jimage):
                if connected_site.jimage != (0, 0, 0):
                    sites_to_append.append(
                        (connected_site.index, connected_site.jimage)
                    )
        sites_to_draw += sites_to_append

    # remove any duplicate sites
    # (can happen when enabling bonded_sites_outside_unit_cell,
    # since this works by following bonds, and a single site outside the
    # unit cell can be bonded to multiple atoms within it)
    return set(sites_to_draw)
def get_structure_graph_scene(
    self,
    origin=None,
    draw_image_atoms=True,
    bonded_sites_outside_unit_cell=True,
    hide_incomplete_edges=False,
    incomplete_edge_length_scale=0.3,
    color_edges_by_edge_weight=False,
    edge_weight_color_scale="coolwarm",
    explicitly_calculate_polyhedra_hull=False,
    legend: Legend | None = None,
    group_by_site_property: str | None = None,
    bond_radius: float = 0.1,
) -> Scene:
    """Build a Scene containing the sites, bonds and unit cell of this graph.

    A scene is requested from every site returned by
    ``self._get_sites_to_draw`` (via ``site.get_scene``); the primitives of
    those per-site scenes are merged by sub-scene name, and the lattice
    scene is added under ``"unit_cell"``.

    Args:
        origin: Origin handed to the returned Scene and its sub-scenes. When
            None, the negated Cartesian coordinates of the fractional point
            (0.5, 0.5, 0.5) are used.
        draw_image_atoms: Forwarded to ``self._get_sites_to_draw``.
        bonded_sites_outside_unit_cell: Forwarded to
            ``self._get_sites_to_draw``.
        hide_incomplete_edges: Forwarded to ``site.get_scene``.
        incomplete_edge_length_scale: Forwarded to ``site.get_scene``.
        color_edges_by_edge_weight: If True and the graph has at least one
            non-zero edge weight, a hex colour is computed per neighbour from
            its weight and passed to ``site.get_scene``.
        edge_weight_color_scale: Colormap name passed to ``get_cmap``.
        explicitly_calculate_polyhedra_hull: Forwarded to ``site.get_scene``.
        legend: Legend passed to ``site.get_scene``; when None,
            ``Legend(self.structure)`` is used.
        group_by_site_property: Name of a site property. When set, the
            contents of each site's ``"atoms"`` sub-scene are grouped into
            one Scene per property value (named by its string form), and the
            first primitive's tooltip is set to that string. The property is
            read with ``site.properties[...]``, so it must exist on every
            drawn site.
        bond_radius: Forwarded to ``site.get_scene``.

    Returns:
        A Scene named ``"StructureGraph"`` whose contents are one sub-scene
        per primitive group.
    """
    # ``origin or ...`` would raise for a multi-element NumPy array, so
    # test for None explicitly.
    if origin is None:
        origin = list(
            -self.structure.lattice.get_cartesian_coords([0.5, 0.5, 0.5])
        )
    if legend is None:
        legend = Legend(self.structure)

    # we get primitives from each site individually, then
    # combine into one big Scene
    primitives = defaultdict(list)

    sites_to_draw = self._get_sites_to_draw(
        draw_image_atoms=draw_image_atoms,
        bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
    )

    color_edges = False
    if color_edges_by_edge_weight:
        weights = [e[2].get("weight") for e in self.graph.edges(data=True)]
        # drop missing (None) and zero weights
        weights = np.array([w for w in weights if w])

        if any(weights):
            # NOTE(review): matplotlib.cm.get_cmap was deprecated in
            # Matplotlib 3.7 and removed in 3.9; switching to
            # ``matplotlib.colormaps[name]`` needs a file-level import
            # change, so it is left as-is here.
            cmap = get_cmap(edge_weight_color_scale)

            # try to keep color scheme symmetric around 0
            weight_max = max([abs(min(weights)), max(weights)])
            weight_min = -weight_max

            def get_weight_color(weight):
                """Map an edge weight to a ``#rrggbb`` colormap colour."""
                if not weight:
                    weight = 0
                x = (weight - weight_min) / (weight_max - weight_min)
                return "#{:02x}{:02x}{:02x}".format(
                    *[int(c * 255) for c in cmap(x)[0:3]]
                )

            color_edges = True

    # we will create sub-scenes for each group of atoms
    # for example, if the Structure has a "wyckoff" site property
    # this might be used to allow grouping by Wyckoff position,
    # this then changes mouseover/interaction behavior with this scene
    # (only filled when group_by_site_property is set)
    grouped_atom_scene_contents = defaultdict(list)

    for idx, jimage in sites_to_draw:
        site = self.structure[idx]
        if jimage != (0, 0, 0):
            connected_sites = self.get_connected_sites(idx, jimage=jimage)
            # translate the site into the requested periodic image
            site = PeriodicSite(
                site.species,
                np.add(site.frac_coords, jimage),
                site.lattice,
                properties=site.properties,
            )
        else:
            connected_sites = self.get_connected_sites(idx)

        # Split neighbours into drawn / not drawn from the FULL neighbour
        # list. Filtering the drawn ones first and deriving the rest from
        # that filtered list would always leave "not drawn" empty.
        connected_sites_drawn = []
        connected_sites_not_drawn = []
        for cs in connected_sites:
            if (cs.index, cs.jimage) in sites_to_draw:
                connected_sites_drawn.append(cs)
            else:
                connected_sites_not_drawn.append(cs)

        if color_edges:
            connected_sites_colors = [
                get_weight_color(cs.weight) for cs in connected_sites_drawn
            ]
            connected_sites_not_drawn_colors = [
                get_weight_color(cs.weight) for cs in connected_sites_not_drawn
            ]
        else:
            connected_sites_colors = None
            connected_sites_not_drawn_colors = None

        site_scene = site.get_scene(
            connected_sites=connected_sites_drawn,
            connected_sites_not_drawn=connected_sites_not_drawn,
            hide_incomplete_edges=hide_incomplete_edges,
            incomplete_edge_length_scale=incomplete_edge_length_scale,
            connected_sites_colors=connected_sites_colors,
            connected_sites_not_drawn_colors=connected_sites_not_drawn_colors,
            explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            legend=legend,
            bond_radius=bond_radius,
        )

        for scene in site_scene.contents:
            if group_by_site_property and scene.name == "atoms":
                group_name = f"{site.properties[group_by_site_property]}"
                scene.contents[0].tooltip = group_name
                grouped_atom_scene_contents[group_name] += scene.contents
            else:
                primitives[scene.name] += scene.contents

    if group_by_site_property:
        atoms_scenes: list[Scene] = [
            Scene(name=k, contents=v)
            for k, v in grouped_atom_scene_contents.items()
        ]
        primitives["atoms"] = atoms_scenes

    primitives["unit_cell"].append(self.structure.lattice.get_scene())

    # wrap each named group of primitives in its own sub-scene
    return Scene(
        name="StructureGraph",
        origin=origin,
        contents=[
            Scene(name=k, contents=v, origin=origin) for k, v in primitives.items()
        ],
    )
# Attach the renderer to StructureGraph: ``get_scene`` is the public entry
# point and relies on the private ``_get_sites_to_draw`` helper, which it
# looks up on ``self`` at call time.
StructureGraph.get_scene = get_structure_graph_scene
StructureGraph._get_sites_to_draw = _get_sites_to_draw