from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field, replace
from itertools import chain
from json import dump
from typing import Any
"""
This module gives a Python interface to generate JSON for the
CrystalToolkitSceneComponent. To use, create a Scene whose contents is a list
of any of the geometric primitives defined below (e.g. Spheres, Cylinders,
etc.) and/or other Scenes. Then use Scene.to_json() to convert the Scene to
the JSON format to pass to CrystalToolkitSceneComponent's data attribute.
"""
class Primitive:
    """
    A Mixin class for standard plottable primitive behavior.

    For now, this just enforces some basic mergeability: subclasses provide a
    ``key`` (Scene.merge_primitives groups primitives with equal keys) and a
    ``merge`` classmethod that combines such a group into a single primitive.

    Note: ``Primitive`` does not use ``ABCMeta``, so the ``@abstractmethod``
    marker below documents intent but is not enforced at instantiation time.
    """

    positions: tuple

    @property
    @abstractmethod
    def key(self):
        """Grouping key; primitives with equal keys may be merged together."""
        raise NotImplementedError

    @classmethod
    def merge(cls, items):
        """Combine a list of primitives sharing the same ``key`` into one."""
        raise NotImplementedError

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Axis-aligned bounding box of ``positions``.

        :return: ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]``
        :raises ValueError: if ``positions`` is empty or its points are not
            3-dimensional (the per-axis unpacking fails)
        """
        x, y, z = zip(*self.positions)
        return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
@dataclass
class Scene:
    """
    A Scene is defined by its name (a string, does not have to be unique),
    and its contents (a list of geometric primitives or other Scenes).

    :param name: name for the scene, does not have to be unique
    :param contents: list of primitives (Spheres, Cylinders, ...) and/or
        nested Scenes
    :param origin: origin of the scene, defaults to (0, 0, 0)
    :param visible: if False, the scene is hidden by default
    :param lattice: optional lattice vectors; omitted from the JSON when None
    :param _meta: optional free-form metadata
    """

    name: str  # name for the scene, does not have to be unique
    contents: list = field(default_factory=list)
    origin: list[float] = field(default=(0, 0, 0))
    visible: bool = True
    lattice: list[list[float]] | None = None
    _meta: dict | None = None

    def __add__(self, other):
        """
        Combine two scenes into a new one, for convenience.

        There is no good way to decide what origin to set for the new scene,
        so origin, visibility and lattice are taken from ``self``.

        :param other: another Scene
        :return: a new Scene named ``"<self.name>_<other.name>"`` whose contents
            are ``self.contents + other.contents`` and whose ``_meta`` maps
            each scene's name to that scene's ``_meta``
        """
        return Scene(
            name=f"{self.name}_{other.name}",
            contents=self.contents + other.contents,
            origin=self.origin,
            visible=self.visible,
            lattice=self.lattice,
            _meta={self.name: self._meta, other.name: other._meta},
        )

    def _repr_mimebundle_(self, include=None, exclude=None):
        """
        Render Scenes using crystaltoolkit-extension for Jupyter Lab.

        ``include``/``exclude`` are accepted for the IPython display protocol
        but are not used.
        """
        return {
            "application/vnd.mp.ctk+json": self.to_json(),
            "text/plain": repr(self),
        }

    def to_json(self):
        """
        Convert a Scene into JSON. It will implicitly assume all None values mean
        that the attribute uses its default value, and so they are removed from
        the JSON to reduce the file size of the resulting JSON.

        Note that this function actually returns a Python dict, but in a format
        that can be converted to a JSON string using the standard library JSON
        encoder. Neither this Scene nor any nested Scene is modified.

        :return: dict in a format that can be parsed by CrystalToolkitSceneComponent
        """
        # Build a fresh Scene holding the merged primitives. The top-level
        # ``_meta`` is not carried over into the JSON.
        merged_scene = Scene(
            name=self.name,
            contents=self.merge_primitives(self.contents),
            origin=self.origin,
            visible=self.visible,
            lattice=self.lattice,
        )

        def remove_defaults(scene_dict):
            """
            Reduce file size of JSON by removing any key whose value is None,
            recursing into nested dicts and into dicts inside lists.
            """
            trimmed_dict = {}
            for k, v in scene_dict.items():
                if isinstance(v, dict):
                    # Keep the (recursively trimmed) dict rather than dropping it.
                    trimmed_dict[k] = remove_defaults(v)
                elif isinstance(v, list):
                    trimmed_dict[k] = [
                        remove_defaults(item) if isinstance(item, dict) else item
                        for item in v
                    ]
                elif v is not None:
                    trimmed_dict[k] = v
            return trimmed_dict

        return remove_defaults(asdict(merged_scene))

    def to_plotly_json(self):
        """
        Easy way to allow Scene objects to be returned from callbacks.
        """
        return self.to_json()

    def to(self, filename):
        """
        Write a Scene to a file. Can be opened by Jupyter Lab if
        Crystal Toolkit extension installed.

        :param filename: The filename (str or path-like, can include path);
            the ".ctk.json" extension is appended if not already present.
        :return: None
        """
        # TODO: find a way to keep the original MSONable object + scene generation options alongside
        filename = str(filename)  # allow pathlib.Path as well as str
        if not filename.endswith(".ctk.json"):
            filename += ".ctk.json"
        with open(filename, "w", encoding="utf-8") as f:
            dump(self.to_json(), f)

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Returns the bounding box coordinates as ``[[min_x, min_y, min_z],
        [max_x, max_y, max_z]]`` over all contents that define a
        ``bounding_box``. Contents without one (e.g. labels) are skipped.
        If no content has one, a zero box at the origin is returned.
        """
        boxes = [
            item.bounding_box
            for item in self.contents
            if hasattr(item, "bounding_box")
        ]
        if not boxes:
            return [[0, 0, 0], [0, 0, 0]]
        mins, maxs = zip(*boxes)
        return [
            [min(axis) for axis in zip(*mins)],
            [max(axis) for axis in zip(*maxs)],
        ]

    @staticmethod
    def merge_primitives(primitives):
        """
        If primitives are of the same type but differ only in position, they
        are merged together. This is a small optimization, has not been benchmarked.

        Nested Scenes are merged recursively into copies; the input objects
        are not mutated.

        :param primitives: list of primitives (Spheres, Cylinders, etc.)
        :return: list of primitives: merged groups first, then the rest
        """
        mergeable = defaultdict(list)
        remainder = []
        for primitive in primitives:
            if isinstance(primitive, Scene):
                # Copy instead of assigning to primitive.contents so that
                # serialising does not alter the caller's scene graph.
                remainder.append(
                    replace(
                        primitive,
                        contents=Scene.merge_primitives(primitive.contents),
                    )
                )
            elif isinstance(primitive, Primitive):
                mergeable[primitive.key].append(primitive)
            else:
                remainder.append(primitive)
        merged = [group[0].merge(group) for group in mergeable.values()]
        return merged + remainder
@dataclass
class Spheres(Primitive):
    """
    Create a set of spheres. All spheres will have the same color, radius and
    segment size (if only drawing a section of a sphere).

    :param positions: This is a list of lists corresponding to the vector
        positions of the spheres.
    :param _animate: optional animation data; merge() concatenates it like
        ``positions``, i.e. it is assumed to hold one entry per position
        (TODO confirm against the front-end)
    :param color: Sphere color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the sphere, defaults to 1.
    :param phiStart: Start angle in radians if drawing only a section of the
        sphere, defaults to 0
    :param phiEnd: End angle in radians if drawing only a section of the
        sphere, defaults to 2*pi
    :param visible: If False, will hide the object by default.
    :param tooltip: optional tooltip text for the spheres
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    radius: float | None = None
    phiStart: float | None = None
    phiEnd: float | None = None
    type: str = field(default="spheres", init=False)  # private field
    visible: bool | None = None
    tooltip: str | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        sphere, so only spheres that render identically apart from position
        are combined. It also records whether ``_animate`` is set.
        """
        return (
            f"sphere_{self.color}_{self.radius}_{self.phiStart}_{self.phiEnd}"
            f"_{self.clickable}_{self.tooltip}_{self.visible}_{self.reference}"
            f"_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, sphere_list):
        """
        Combine spheres sharing the same ``key`` into a single Spheres.

        Positions are concatenated. ``_animate`` is concatenated when every
        sphere has it and is otherwise left as None. ``_meta`` is not carried over.
        """
        first = sphere_list[0]
        new_positions = list(
            chain.from_iterable(sphere.positions for sphere in sphere_list)
        )
        if all(sphere._animate is not None for sphere in sphere_list):
            new_animate = list(
                chain.from_iterable(sphere._animate for sphere in sphere_list)
            )
        else:
            new_animate = None
        return cls(
            positions=new_positions,
            _animate=new_animate,
            color=first.color,
            radius=first.radius,
            phiStart=first.phiStart,
            phiEnd=first.phiEnd,
            visible=first.visible,
            tooltip=first.tooltip,
            clickable=first.clickable,
            reference=first.reference,
        )
@dataclass
class Ellipsoids(Primitive):
    """
    Create a set of ellipsoids. All ellipsoids will have the same color, scale and
    segment size (if only drawing a section of an ellipsoid).

    :param scale: This is the scale to apply to the x, y and z axis of the
        ellipsoid prior to rotation to the target axes
    :param positions: This is a list of lists corresponding to the vector
        positions of the ellipsoids.
    :param rotate_to: This is a list of vectors that specify the direction the
        major axis of the ellipsoid should point towards. The major axis is the
        z-axis: (0,0,1)
    :param _animate: optional animation data; merge() concatenates it like
        ``positions``, i.e. it is assumed to hold one entry per position
        (TODO confirm against the front-end)
    :param color: Ellipsoid color as a hexadecimal string, e.g. #ff0000
    :param phiStart: Start angle in radians if drawing only a section of the
        ellipsoid, defaults to 0
    :param phiEnd: End angle in radians if drawing only a section of the
        ellipsoid, defaults to 2*pi
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    scale: list[float]
    positions: list[list[float]]
    rotate_to: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    phiStart: float | None = None
    phiEnd: float | None = None
    type: str = field(default="ellipsoids", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        ellipsoid. It also records whether ``_animate`` is set, so that
        animated and non-animated sets are never mixed.
        """
        return (
            f"ellipsoid_{self.color}_{self.scale}_{self.phiStart}_{self.phiEnd}"
            f"_{self.visible}_{self.clickable}_{self.reference}"
            f"_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, ellipsoid_list):
        """
        Combine ellipsoids sharing the same ``key`` into a single Ellipsoids.

        ``positions`` and ``rotate_to`` are concatenated. ``_animate`` is
        concatenated only when every ellipsoid has it, so it stays aligned with
        ``positions``; otherwise it is None. ``_meta`` is not carried over.
        """
        first = ellipsoid_list[0]
        new_positions = list(
            chain.from_iterable(ellipsoid.positions for ellipsoid in ellipsoid_list)
        )
        new_rotate_to = list(
            chain.from_iterable(ellipsoid.rotate_to for ellipsoid in ellipsoid_list)
        )
        if all(ellipsoid._animate is not None for ellipsoid in ellipsoid_list):
            new_animate = list(
                chain.from_iterable(
                    ellipsoid._animate for ellipsoid in ellipsoid_list
                )
            )
        else:
            new_animate = None
        return cls(
            positions=new_positions,
            rotate_to=new_rotate_to,
            _animate=new_animate,
            scale=first.scale,
            color=first.color,
            phiStart=first.phiStart,
            phiEnd=first.phiEnd,
            visible=first.visible,
            clickable=first.clickable,
            reference=first.reference,
        )
@dataclass
class Cylinders(Primitive):
    """
    Create a set of cylinders. All cylinders will have the same color and
    radius.

    :param positionPairs: This is a list of pairs of lists corresponding to the
        start and end position of the cylinder.
    :param _animate: optional animation data; merge() concatenates it like
        ``positionPairs``, i.e. it is assumed to hold one entry per pair
        (TODO confirm against the front-end)
    :param color: Cylinder color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the cylinder, defaults to 1.
    :param visible: If False, will hide the object by default.
    :param tooltip: optional tooltip text for the cylinders
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    positionPairs: list[list[list[float]]]
    _animate: list[list[list[float]]] | None = None
    color: str | None = None
    radius: float | None = None
    type: str = field(default="cylinders", init=False)  # private field
    visible: bool | None = None
    tooltip: str | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        cylinder. It also records whether ``_animate`` is set.
        """
        return (
            f"cylinder_{self.color}_{self.radius}_{self.reference}"
            f"_{self.visible}_{self.clickable}_{self.tooltip}"
            f"_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, cylinder_list):
        """
        Combine cylinders sharing the same ``key`` into a single Cylinders.

        ``positionPairs`` are concatenated. ``_animate`` is concatenated when
        every cylinder has it and is otherwise None. ``_meta`` is not carried over.
        """
        first = cylinder_list[0]
        new_positionPairs = list(
            chain.from_iterable(cylinder.positionPairs for cylinder in cylinder_list)
        )
        if all(cylinder._animate is not None for cylinder in cylinder_list):
            new_animate = list(
                chain.from_iterable(cylinder._animate for cylinder in cylinder_list)
            )
        else:
            new_animate = None
        return cls(
            positionPairs=new_positionPairs,
            _animate=new_animate,
            color=first.color,
            radius=first.radius,
            visible=first.visible,
            tooltip=first.tooltip,
            clickable=first.clickable,
            reference=first.reference,
        )

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Axis-aligned bounding box over all cylinder end points.

        :return: ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]``
        """
        x, y, z = zip(*chain.from_iterable(self.positionPairs))
        return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
@dataclass
class Cubes(Primitive):
    """
    Create a set of cubes. All cubes will have the same color and width.

    :param positions: This is a list of lists corresponding to the vector
        positions of the cubes.
    :param _animate: optional animation data; merge() concatenates it like
        ``positions``, i.e. it is assumed to hold one entry per position
        (TODO confirm against the front-end)
    :param color: Cube color as a hexadecimal string, e.g. #ff0000
    :param width: The width of the cube, defaults to 1.
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    width: float | None = None
    type: str = field(default="cubes", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        cube. It also records whether ``_animate`` is set.
        """
        return (
            f"cube_{self.color}_{self.width}_{self.reference}"
            f"_{self.visible}_{self.clickable}_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, cube_list):
        """
        Combine cubes sharing the same ``key`` into a single Cubes.

        Positions are concatenated. ``_animate`` is concatenated when every
        cube has it and is otherwise None. ``_meta`` is not carried over.
        """
        first = cube_list[0]
        new_positions = list(chain.from_iterable(cube.positions for cube in cube_list))
        if all(cube._animate is not None for cube in cube_list):
            new_animate = list(chain.from_iterable(cube._animate for cube in cube_list))
        else:
            new_animate = None
        return cls(
            positions=new_positions,
            _animate=new_animate,
            color=first.color,
            width=first.width,
            visible=first.visible,
            clickable=first.clickable,
            reference=first.reference,
        )
@dataclass
class Lines(Primitive):
    """
    Create a set of lines. All lines will have the same color, thickness and
    (optional) dashes.

    :param positions: This is a list of lists corresponding to the positions of
        the lines. Each consecutive pair of vectors corresponds to the start and
        end position of a line segment (line segments do not have to be joined
        together).
    :param _animate: optional animation data; merge() concatenates it like
        ``positions``, i.e. it is assumed to hold one entry per position
        (TODO confirm against the front-end)
    :param color: Line color as a hexadecimal string, e.g. #ff0000
    :param linewidth: The width of the line, defaults to 1
    :param scale: Optional, if provided will set a global scale for line dashes.
    :param dashSize: Optional, if provided will specify length of line dashes.
    :param gapSize: Optional, if provided will specify gap between line dashes.
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    linewidth: float | None = None
    scale: float | None = None
    dashSize: float | None = None
    gapSize: float | None = None
    type: str = field(default="lines", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        line set (including ``scale``). It also records whether ``_animate`` is set.
        """
        return (
            f"line_{self.color}_{self.linewidth}_{self.dashSize}_{self.gapSize}"
            f"_{self.reference}_{self.scale}_{self.visible}_{self.clickable}"
            f"_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, line_list):
        """
        Combine line sets sharing the same ``key`` into a single Lines.

        Positions are concatenated, so segment pairs are preserved as long as
        each input has an even number of positions. ``_animate`` is
        concatenated when every input has it and is otherwise None. ``_meta``
        is not carried over.
        """
        first = line_list[0]
        new_positions = list(chain.from_iterable(line.positions for line in line_list))
        if all(line._animate is not None for line in line_list):
            new_animate = list(chain.from_iterable(line._animate for line in line_list))
        else:
            new_animate = None
        return cls(
            positions=new_positions,
            _animate=new_animate,
            color=first.color,
            linewidth=first.linewidth,
            scale=first.scale,
            dashSize=first.dashSize,
            gapSize=first.gapSize,
            visible=first.visible,
            clickable=first.clickable,
            reference=first.reference,
        )
@dataclass
class Surface:
    """
    Define a surface by its vertices. Please also provide normals if known.
    Opacity can be set to enable transparency, but note that the current
    Three.js renderer doesn't support nested transparent objects very well.

    Surfaces are not a ``Primitive``, so Scene.merge_primitives never merges them.
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    normals: list[list[float]] | None = None
    color: str | None = None
    opacity: float | None = None
    show_edges: bool = False
    type: str = field(default="surface", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Placeholder bounding box: always a zero-size box at the origin.

        It is not derived from ``positions``, so any aggregate that folds this
        box into its min/max will include the origin.
        """
        return [[0, 0, 0], [0, 0, 0]]
@dataclass
class Convex:
    """
    Create a surface from the convex hull formed by list of points. Note that
    at least four points must be specified. The current Three.js renderer uses
    the QuickHull algorithm. Opacity can be set to enable transparency, but note
    that the current Three.js renderer doesn't support nested transparent
    objects very well.

    Convex is not a ``Primitive``, so Scene.merge_primitives never merges it.
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    opacity: float | None = None
    type: str = field(default="convex", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Placeholder bounding box: always a zero-size box at the origin.

        It is not derived from ``positions``, so any aggregate that folds this
        box into its min/max will include the origin.
        """
        return [[0, 0, 0], [0, 0, 0]]
@dataclass
class Arrows(Primitive):
    """
    Create a set of arrows. All arrows will have the same color, radius and
    head shape.

    :param positionPairs: This is a list of pairs of lists corresponding to the
        start and end position of each arrow.
    :param _animate: optional animation data; merge() concatenates it like
        ``positionPairs``, i.e. it is assumed to hold one entry per pair
        (TODO confirm against the front-end)
    :param color: Arrow color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the arrow shaft, defaults to 1.
    :param headLength: Optional length of the arrow head.
    :param headWidth: Optional width of the arrow head.
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback
    """

    positionPairs: list[list[list[float]]]
    _animate: list[list[list[float]]] | None = None
    color: str | None = None
    radius: float | None = None
    headLength: float | None = None
    headWidth: float | None = None
    type: str = field(default="arrows", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """
        Grouping key. It covers every attribute merge() copies from the first
        arrow set. It also records whether ``_animate`` is set.
        """
        return (
            f"arrow_{self.color}_{self.radius}_{self.headLength}_{self.headWidth}"
            f"_{self.reference}_{self.visible}_{self.clickable}"
            f"_{self._animate is not None}"
        )

    @classmethod
    def merge(cls, arrow_list):
        """
        Combine arrow sets sharing the same ``key`` into a single Arrows.

        ``positionPairs`` are concatenated. ``_animate`` is concatenated when
        every arrow set has it and is otherwise None. ``_meta`` is not carried over.
        """
        first = arrow_list[0]
        new_positionPairs = list(
            chain.from_iterable(arrow.positionPairs for arrow in arrow_list)
        )
        if all(arrow._animate is not None for arrow in arrow_list):
            new_animate = list(chain.from_iterable(arrow._animate for arrow in arrow_list))
        else:
            new_animate = None
        return cls(
            positionPairs=new_positionPairs,
            _animate=new_animate,
            color=first.color,
            radius=first.radius,
            headLength=first.headLength,
            headWidth=first.headWidth,
            visible=first.visible,
            clickable=first.clickable,
            reference=first.reference,
        )

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Axis-aligned bounding box over all arrow start/end points.

        :return: ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]``
        """
        x, y, z = zip(*chain.from_iterable(self.positionPairs))
        return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
@dataclass
class Label:
    """
    Add a label to an object.

    :param label: the label text
    :param labelHover: optional alternative text (presumably shown on hover)
    :param position: optional position of the label
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback

    Label defines no ``bounding_box`` and is not a ``Primitive``.
    """

    label: str
    labelHover: str | None = None
    position: list[list[float]] | None = None
    type: str = field(default="labels", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None
@dataclass
class Bezier:
    """
    A tube shaped by Bézier control points.

    :param controlPoints: optional list of control-point groups
    :param color: optional list of colors
    :param radius: optional list of radii
    :param visible: If False, will hide the object by default.
    :param clickable: if true, allows this primitive to be clicked
        and trigger an event
    :param reference: name to reference the primitive for callback

    Bezier defines no ``bounding_box`` and is not a ``Primitive``.
    """

    controlPoints: list[list[list[float]]] | None = None
    color: list[str] | None = None
    radius: list[float] | None = None
    type: str = field(default="bezier", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None