from __future__ import annotations
from collections import defaultdict
from itertools import combinations
import numpy as np
from matplotlib.cm import get_cmap
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.core.sites import PeriodicSite
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.scene import Scene
def _get_sites_to_draw(
    self, draw_image_atoms=True, bonded_sites_outside_unit_cell=False
):
    """
    Return the (site index, lattice image) pairs that should be drawn.

    Every site of ``self.structure`` is always included with image
    ``(0, 0, 0)``.

    Args:
        draw_image_atoms: if True, also include periodic images of sites
            lying on a cell boundary. For every non-empty combination of the
            axes along which a site's fractional coordinate is within 0.05 of
            0, the image shifted by +1 along those axes is added; likewise,
            coordinates within 0.05 of 1 produce images shifted by -1.
        bonded_sites_outside_unit_cell: if True, additionally include every
            site with a non-zero image that ``self.get_connected_sites``
            reports as bonded to one of the sites selected above.

    Returns:
        A set of ``(index, jimage)`` tuples, ``jimage`` being a 3-tuple of
        ints. A set is used to drop duplicates, which arise when a single
        site outside the unit cell is bonded to several sites inside it.
    """
    sites_to_draw = [(idx, (0, 0, 0)) for idx in range(len(self.structure))]

    if draw_image_atoms:
        for idx, site in enumerate(self.structure):
            # coordinates near 0 get an image at +1, near 1 an image at -1
            for boundary, shift in ((0, 1), (1, -1)):
                boundary_axes = [
                    axis
                    for axis, frac in enumerate(site.frac_coords)
                    if np.allclose(frac, boundary, atol=0.05)
                ]
                # a site on an edge/corner needs an image for each subset of
                # its boundary axes, not just one per axis
                for n_axes in range(1, len(boundary_axes) + 1):
                    for axes in combinations(boundary_axes, n_axes):
                        jimage = tuple(
                            shift * int(axis in axes) for axis in range(3)
                        )
                        sites_to_draw.append((idx, jimage))

    if bonded_sites_outside_unit_cell:
        bonded_images = []
        for n, jimage in sites_to_draw:
            for connected_site in self.get_connected_sites(n, jimage=jimage):
                if connected_site.jimage != (0, 0, 0):
                    bonded_images.append(
                        (connected_site.index, connected_site.jimage)
                    )
        sites_to_draw += bonded_images

    return set(sites_to_draw)
def _get_edge_weight_color_function(graph, edge_weight_color_scale):
    """
    Build a function that maps an edge weight to a ``"#rrggbb"`` colour.

    Args:
        graph: graph whose edge ``"weight"`` data defines the colour range.
        edge_weight_color_scale: name of a matplotlib colormap.

    Returns:
        A callable ``weight -> "#rrggbb"``, or None when no edge carries a
        truthy ``"weight"`` (so there is nothing to colour by).
    """
    weights = [edge[2].get("weight") for edge in graph.edges(data=True)]
    weights = np.array([w for w in weights if w])
    if weights.size == 0:
        return None

    cmap = get_cmap(edge_weight_color_scale)
    # keep the colour scale symmetric around 0, so that a weight of 0 maps
    # to the middle of the colormap
    weight_max = max(abs(weights.min()), weights.max())
    weight_min = -weight_max

    def get_weight_color(weight):
        if not weight:  # missing (None) weights are coloured like 0
            weight = 0
        x = (weight - weight_min) / (weight_max - weight_min)
        return "#{:02x}{:02x}{:02x}".format(
            *[int(c * 255) for c in cmap(x)[0:3]]
        )

    return get_weight_color


def get_structure_graph_scene(
    self,
    origin=None,
    draw_image_atoms=True,
    bonded_sites_outside_unit_cell=True,
    hide_incomplete_edges=False,
    incomplete_edge_length_scale=0.3,
    color_edges_by_edge_weight=False,
    edge_weight_color_scale="coolwarm",
    explicitly_calculate_polyhedra_hull=False,
    legend: Legend | None = None,
    group_by_site_property: str | None = None,
    bond_radius: float = 0.1,
) -> Scene:
    """
    Build a Crystal Toolkit Scene for this StructureGraph.

    Every site returned by ``self._get_sites_to_draw`` is rendered with
    ``PeriodicSite.get_scene``; the resulting primitives are pooled by
    sub-scene name and wrapped, together with the lattice's unit cell scene,
    in a single Scene named "StructureGraph".

    Args:
        origin: scene origin; when None, defaults to minus the Cartesian
            coordinates of fractional point [0.5, 0.5, 0.5].
        draw_image_atoms: forwarded to ``_get_sites_to_draw``.
        bonded_sites_outside_unit_cell: forwarded to ``_get_sites_to_draw``.
        hide_incomplete_edges: forwarded to ``PeriodicSite.get_scene``.
        incomplete_edge_length_scale: forwarded to ``PeriodicSite.get_scene``.
        color_edges_by_edge_weight: if True and at least one graph edge has a
            non-zero ``"weight"``, bonds are coloured by weight.
        edge_weight_color_scale: matplotlib colormap name used for the above.
        explicitly_calculate_polyhedra_hull: forwarded to
            ``PeriodicSite.get_scene``.
        legend: Legend passed to each site; a new ``Legend(self.structure)``
            is created when None.
        group_by_site_property: if given, "atoms" primitives are grouped into
            sub-scenes named after this site property's value, and that value
            is set as the tooltip of each site's first atom primitive.
        bond_radius: forwarded to ``PeriodicSite.get_scene``.

    Returns:
        The combined Scene.
    """
    # explicit None checks: ``origin or ...`` fails for numpy-array origins
    if origin is None:
        origin = list(
            -self.structure.lattice.get_cartesian_coords([0.5, 0.5, 0.5])
        )
    if legend is None:
        legend = Legend(self.structure)

    sites_to_draw = self._get_sites_to_draw(
        draw_image_atoms=draw_image_atoms,
        bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
    )

    get_weight_color = (
        _get_edge_weight_color_function(self.graph, edge_weight_color_scale)
        if color_edges_by_edge_weight
        else None
    )

    # primitives are collected per sub-scene name (e.g. "atoms") from every
    # site individually, then combined into one big Scene
    primitives = defaultdict(list)
    # when grouping by a site property (e.g. a "wyckoff" property, to group
    # atoms by Wyckoff position), atom primitives are collected here instead;
    # this changes mouseover/interaction behaviour of the scene
    grouped_atom_scene_contents = defaultdict(list)

    for idx, jimage in sites_to_draw:
        site = self.structure[idx]
        if jimage != (0, 0, 0):
            # translate the site to its periodic image
            site = PeriodicSite(
                site.species,
                np.add(site.frac_coords, jimage),
                site.lattice,
                properties=site.properties,
            )

        # partition the FULL neighbour list: neighbours that are drawn too,
        # and neighbours that are not (handed to get_scene separately, where
        # hide_incomplete_edges / incomplete_edge_length_scale apply to them)
        connected_sites = []
        connected_sites_not_drawn = []
        for connected_site in self.get_connected_sites(idx, jimage=jimage):
            if (connected_site.index, connected_site.jimage) in sites_to_draw:
                connected_sites.append(connected_site)
            else:
                connected_sites_not_drawn.append(connected_site)

        if get_weight_color is None:
            connected_sites_colors = None
            connected_sites_not_drawn_colors = None
        else:
            connected_sites_colors = [
                get_weight_color(cs.weight) for cs in connected_sites
            ]
            connected_sites_not_drawn_colors = [
                get_weight_color(cs.weight) for cs in connected_sites_not_drawn
            ]

        site_scene = site.get_scene(
            connected_sites=connected_sites,
            connected_sites_not_drawn=connected_sites_not_drawn,
            hide_incomplete_edges=hide_incomplete_edges,
            incomplete_edge_length_scale=incomplete_edge_length_scale,
            connected_sites_colors=connected_sites_colors,
            connected_sites_not_drawn_colors=connected_sites_not_drawn_colors,
            explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
            legend=legend,
            bond_radius=bond_radius,
        )

        for scene in site_scene.contents:
            if group_by_site_property and scene.name == "atoms":
                group_name = f"{site.properties[group_by_site_property]}"
                scene.contents[0].tooltip = group_name
                grouped_atom_scene_contents[group_name] += scene.contents
            else:
                primitives[scene.name] += scene.contents

    if group_by_site_property:
        primitives["atoms"] = [
            Scene(name=group_name, contents=contents)
            for group_name, contents in grouped_atom_scene_contents.items()
        ]

    primitives["unit_cell"].append(self.structure.lattice.get_scene())

    return Scene(
        name="StructureGraph",
        origin=origin,
        contents=[
            Scene(name=name, contents=contents, origin=origin)
            for name, contents in primitives.items()
        ],
    )
# Install the rendering helpers defined in this module on pymatgen's
# StructureGraph, so that ``StructureGraph.get_scene()`` builds a Scene.
for _attr_name, _method in (
    ("_get_sites_to_draw", _get_sites_to_draw),
    ("get_scene", get_structure_graph_scene),
):
    setattr(StructureGraph, _attr_name, _method)
del _attr_name, _method