from __future__ import annotations
from collections import defaultdict
from itertools import combinations
import numpy as np
from pymatgen.core.sites import PeriodicSite
from pymatgen.core.structure import Structure
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.scene import Scene
def _get_sites_to_draw(self, draw_image_atoms=True):
    """
    Returns a list of site indices and image vectors.
    """
    sites_to_draw = [(idx, (0, 0, 0)) for idx in range(len(self))]
    if draw_image_atoms:
        for idx, site in enumerate(self):
            zero_elements = [
                idx
                for idx, f in enumerate(site.frac_coords)
                if np.allclose(f, 0, atol=0.05)
            ]
            coord_permutations = [
                x
                for tmp_ in range(1, len(zero_elements) + 1)
                for x in combinations(zero_elements, tmp_)
            ]
            for perm in coord_permutations:
                sites_to_draw.append(
                    (idx, (int(0 in perm), int(1 in perm), int(2 in perm)))
                )
            one_elements = [
                idx
                for idx, f in enumerate(site.frac_coords)
                if np.allclose(f, 1, atol=0.05)
            ]
            coord_permutations = [
                x
                for tmp_ in range(1, len(one_elements) + 1)
                for x in combinations(one_elements, tmp_)
            ]
            for perm in coord_permutations:
                sites_to_draw.append(
                    (idx, (-int(0 in perm), -int(1 in perm), -int(2 in perm)))
                )
    return set(sites_to_draw)
[docs]def get_structure_scene(
    self,
    origin: list[float] = None,
    legend: Legend | None = None,
    draw_image_atoms: bool = True,
) -> Scene:
    """
    Create CTK objects for the lattice and sties
    Args:
        self:  Structure object
        origin: fractional coordinate of the origin
        legend: Legend for the sites
        draw_image_atoms: If true draw image atoms that are just outside the
        periodic boundary
    Returns:
        CTK scene object to be rendered
    """
    origin = origin or list(-self.lattice.get_cartesian_coords([0.5, 0.5, 0.5]))
    legend = legend or Legend(self)
    primitives = defaultdict(list)
    sites_to_draw = self._get_sites_to_draw(draw_image_atoms=draw_image_atoms)
    for (idx, jimage) in sites_to_draw:
        site = self[idx]
        if jimage != (0, 0, 0):
            site = PeriodicSite(
                site.species,
                np.add(site.frac_coords, jimage),
                site.lattice,
                properties=site.properties,
            )
        site_scene = site.get_scene(legend=legend)
        for scene in site_scene.contents:
            primitives[scene.name] += scene.contents
    primitives["unit_cell"].append(self.lattice.get_scene())
    lattice_vectors = self.lattice.matrix.tolist()
    return Scene(
        name="Structure",
        origin=origin,
        lattice=lattice_vectors,
        contents=[
            Scene(name=k, contents=v, origin=origin) for k, v in primitives.items()
        ],
    ) 
Structure._get_sites_to_draw = _get_sites_to_draw
Structure.get_scene = get_structure_scene