Source code for crystal_toolkit.core.scene

from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from itertools import chain
from json import dump
from typing import Any

"""
This module gives a Python interface to generate JSON for the
CrystalToolkitSceneComponent. To use, create a Scene whose contents can either be
a list of any of the geometric primitives defined below (e.g. Spheres,
Cylinders, etc.) or can be another Scene. Then use scene_to_json() to convert
the Scene to the JSON format to pass to CrystalToolkitSceneComponent's data attribute.
"""


class Primitive:
    """
    Mixin supplying the behaviour shared by plottable primitives.

    For now this only lays down the merge protocol (``key`` and ``merge``)
    and a default bounding box computed from ``positions``.
    """

    positions: tuple

    @property
    @abstractmethod
    def key(self):
        """Grouping key for merging; must be provided by subclasses."""
        raise NotImplementedError

    @classmethod
    def merge(cls, items):
        """Combine ``items`` into a single primitive; must be overridden."""
        raise NotImplementedError

    @property
    def bounding_box(self) -> list[list[float]]:
        """Return ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]`` of ``positions``."""
        # Transpose the list of (x, y, z) points into one tuple per axis.
        xs, ys, zs = zip(*self.positions)
        axes = (xs, ys, zs)
        return [[min(axis) for axis in axes], [max(axis) for axis in axes]]
@dataclass
class Scene:
    """
    A Scene is defined by its name (a string, does not have to be unique),
    and its contents (a list of geometric primitives or other Scenes).
    """

    name: str  # name for the scene, does not have to be unique
    contents: list = field(default_factory=list)
    origin: list[float] = field(default=(0, 0, 0))
    visible: bool = True
    lattice: list[list[float]] | None = None
    _meta: dict | None = None

    def __add__(self, other):
        """
        For convenience to combine multiple scenes.

        There is no good way to decide what origin to set for the new scene,
        so the origin, visibility and lattice of the left operand are used.

        :param other: another Scene
        :return: a new Scene holding the contents of both operands; its
            ``_meta`` maps each operand's name to that operand's ``_meta``
        """
        return Scene(
            name=f"{self.name}_{other.name}",
            contents=self.contents + other.contents,
            origin=self.origin,
            visible=self.visible,
            lattice=self.lattice,
            _meta={self.name: self._meta, other.name: other._meta},
        )

    def _repr_mimebundle_(self, include=None, exclude=None):
        """
        Render Scenes using crystaltoolkit-extension for Jupyter Lab.
        """
        return {
            "application/vnd.mp.ctk+json": self.to_json(),
            "text/plain": repr(self),
        }

    def to_json(self):
        """
        Convert a Scene into JSON. It will implicitly assume all None values
        mean that the attribute uses its default value, and so they are
        removed from the output to reduce the size of the resulting JSON.

        Note that this function actually returns a Python dict, but in a
        format that can be converted to a JSON string using the standard
        library JSON encoder.

        :return: dict in a format that can be parsed by
            CrystalToolkitSceneComponent
        """
        merged_scene = Scene(
            name=self.name,
            contents=self.merge_primitives(self.contents),
            origin=self.origin,
            # Carry the visibility flag through; without it a hidden
            # top-level scene would be serialised as visible.
            visible=self.visible,
            lattice=self.lattice,
        )

        def remove_defaults(scene_dict):
            """
            Reduce file size of JSON by removing any key which is just its
            default value.
            """
            trimmed_dict = {}
            for k, v in scene_dict.items():
                if isinstance(v, dict):
                    # NOTE(review): dict-valued entries (e.g. a dict ``_meta``)
                    # have never been written to the output. That behaviour is
                    # kept here because consumers may rely on it -- confirm
                    # whether such values should be emitted instead.
                    continue
                if isinstance(v, list):
                    trimmed_dict[k] = [
                        remove_defaults(item) if isinstance(item, dict) else item
                        for item in v
                    ]
                elif v is not None:
                    trimmed_dict[k] = v
            return trimmed_dict

        return remove_defaults(asdict(merged_scene))

    def to_plotly_json(self):
        """
        Easy way to allow Scene objects to be returned from callbacks.
        """
        return self.to_json()

    def to(self, filename):
        """
        Write a Scene to a file. Can be opened by Jupyter Lab if the Crystal
        Toolkit extension is installed.

        :param filename: The filename (can include path); the ``.ctk.json``
            extension is appended if not already present.
        :return: None
        """
        # TODO: find a way to keep the original MSONable object + scene
        # generation options alongside
        if not filename.endswith(".ctk.json"):
            filename += ".ctk.json"

        with open(filename, "w", encoding="utf-8") as f:
            dump(self.to_json(), f)

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Returns the bounding box coordinates as
        ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]``, built from the
        ``bounding_box`` of every item in ``contents``. An empty scene
        reports a degenerate box at the origin.
        """
        if not self.contents:
            return [[0, 0, 0], [0, 0, 0]]

        lower_corners, upper_corners = zip(*[p.bounding_box for p in self.contents])
        # Transpose the corners so each axis can be reduced independently.
        min_x, min_y, min_z = map(min, zip(*lower_corners))
        max_x, max_y, max_z = map(max, zip(*upper_corners))
        return [[min_x, min_y, min_z], [max_x, max_y, max_z]]

    @staticmethod
    def merge_primitives(primitives):
        """
        If primitives are of the same type but differ only in position, they
        are merged together. This is a small optimization, has not been
        benchmarked.

        Nested Scenes are processed recursively. The input objects are left
        untouched: a nested Scene is replaced in the result by a copy holding
        the merged contents, so serialising a scene never alters it.

        :param primitives: list of primitives (Spheres, Cylinders, etc.)
        :return: list of merged primitives followed by everything that could
            not be merged (nested Scenes and non-Primitive items)
        """
        # Local import keeps this block self-contained.
        from dataclasses import replace

        mergeable = defaultdict(list)
        remainder = []

        for primitive in primitives:
            if isinstance(primitive, Scene):
                remainder.append(
                    replace(
                        primitive,
                        contents=Scene.merge_primitives(primitive.contents),
                    )
                )
            elif isinstance(primitive, Primitive):
                mergeable[primitive.key].append(primitive)
            else:
                remainder.append(primitive)

        merged = [group[0].merge(group) for group in mergeable.values()]
        return merged + remainder
@dataclass
class Spheres(Primitive):
    """
    Create a set of spheres. All spheres will have the same color, radius and
    segment size (if only drawing a section of a sphere).

    :param positions: This is a list of lists corresponding to the vector
        positions of the spheres.
    :param color: Sphere color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the sphere, defaults to 1.
    :param phiStart: Start angle in radians if drawing only a section of the
        sphere, defaults to 0
    :param phiEnd: End angle in radians if drawing only a section of the
        sphere, defaults to 2*pi
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    radius: float | None = None
    phiStart: float | None = None
    phiEnd: float | None = None
    type: str = field(default="spheres", init=False)  # private field
    visible: bool | None = None
    tooltip: str | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: spheres sharing it are candidates for merging."""
        parts = (
            "sphere",
            self.color,
            self.radius,
            self.phiStart,
            self.phiEnd,
            self.clickable,
            self.tooltip,
        )
        return "_".join(f"{part}" for part in parts)

    @classmethod
    def merge(cls, sphere_list):
        """
        Merge several Spheres into one holding all of their positions.

        Every other attribute that is carried over is taken from the first
        item; ``_animate``, ``reference`` and ``_meta`` are not carried over.
        """
        merged_positions = [
            position for sphere in sphere_list for position in sphere.positions
        ]
        template = sphere_list[0]
        return cls(
            positions=merged_positions,
            color=template.color,
            radius=template.radius,
            phiStart=template.phiStart,
            phiEnd=template.phiEnd,
            visible=template.visible,
            clickable=template.clickable,
            tooltip=template.tooltip,
        )
@dataclass
class Ellipsoids(Primitive):
    """
    Create a set of ellipsoids. All ellipsoids will have the same color,
    radius and segment size (if only drawing a section of an ellipsoid).

    :param scale: This is the scale to apply to the x, y and z axis of the
        ellipsoid prior to rotation to the target axes
    :param positions: This is a list of lists corresponding to the vector
        positions of the ellipsoids.
    :param rotate_to: This is a list of vectors that specify the direction
        the major axis of the ellipsoid should point towards. The major axis
        is the z-axis: (0,0,1)
    :param color: Ellipsoid color as a hexadecimal string, e.g. #ff0000
    :param phiStart: Start angle in radians if drawing only a section of the
        ellipsoid, defaults to 0
    :param phiEnd: End angle in radians if drawing only a section of the
        ellipsoid, defaults to 2*pi
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    scale: list[float]
    positions: list[list[float]]
    rotate_to: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    phiStart: float | None = None
    phiEnd: float | None = None
    type: str = field(default="ellipsoids", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: ellipsoids sharing it are candidates for merging."""
        parts = ("ellipsoid", self.color, self.scale, self.phiStart, self.phiEnd)
        return "_".join(f"{part}" for part in parts)

    @classmethod
    def merge(cls, ellipsoid_list):
        """
        Merge several Ellipsoids into one.

        Positions, ``rotate_to`` vectors and any non-empty ``_animate``
        entries are concatenated in input order; scale, color, angles and
        visibility are taken from the first item. ``clickable``,
        ``reference`` and ``_meta`` are not carried over.
        """
        merged_positions = [
            position
            for ellipsoid in ellipsoid_list
            for position in ellipsoid.positions
        ]
        merged_rotations = [
            direction
            for ellipsoid in ellipsoid_list
            for direction in ellipsoid.rotate_to
        ]
        # Items without animation data contribute nothing here.
        merged_animations = [
            frame
            for ellipsoid in ellipsoid_list
            if ellipsoid._animate
            for frame in ellipsoid._animate
        ]

        template = ellipsoid_list[0]
        return cls(
            positions=merged_positions,
            rotate_to=merged_rotations,
            _animate=merged_animations,
            scale=template.scale,
            color=template.color,
            phiStart=template.phiStart,
            phiEnd=template.phiEnd,
            visible=template.visible,
        )
@dataclass
class Cylinders(Primitive):
    """
    Create a set of cylinders. All cylinders will have the same color and
    radius.

    :param positionPairs: This is a list of pairs of lists corresponding to
        the start and end position of the cylinder.
    :param color: Cylinder color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the cylinder, defaults to 1.
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    positionPairs: list[list[list[float]]]
    _animate: list[list[list[float]]] | None = None
    color: str | None = None
    radius: float | None = None
    type: str = field(default="cylinders", init=False)  # private field
    visible: bool | None = None
    tooltip: str | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: cylinders sharing it are candidates for merging."""
        return f"cylinder_{self.color}_{self.radius}_{self.reference}"

    @classmethod
    def merge(cls, cylinder_list):
        """
        Merge several Cylinders into one holding all of their position pairs.

        Color, radius and visibility are taken from the first item.
        NOTE(review): ``reference`` is part of ``key`` but, like ``tooltip``,
        ``clickable``, ``_animate`` and ``_meta``, it is not carried over to
        the merged object -- confirm whether that is intended.
        """
        new_positionPairs = list(
            chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list])
        )

        return cls(
            positionPairs=new_positionPairs,
            color=cylinder_list[0].color,
            radius=cylinder_list[0].radius,
            visible=cylinder_list[0].visible,
        )

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Return ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]`` over every
        start and end point in ``positionPairs``. The radius is not taken
        into account.
        """
        # Flatten the (start, end) pairs into one stream of points, then
        # transpose into per-axis tuples.
        x, y, z = zip(*chain.from_iterable(self.positionPairs))
        # The upper corner must use max(); it previously repeated the minimum
        # corner, collapsing the box to a single point.
        return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
@dataclass
class Cubes(Primitive):
    """
    Create a set of cubes. All cubes will have the same color and width.

    :param positions: This is a list of lists corresponding to the vector
        positions of the cubes.
    :param color: Cube color as a hexadecimal string, e.g. #ff0000
    :param width: The width of the cube, defaults to 1.
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    width: float | None = None
    type: str = field(default="cubes", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: cubes sharing it are candidates for merging."""
        parts = ("cube", self.color, self.width, self.reference)
        return "_".join(f"{part}" for part in parts)

    @classmethod
    def merge(cls, cube_list):
        """
        Merge several Cubes into one holding all of their positions.

        Color, width and visibility are taken from the first item; no other
        attribute is carried over.
        """
        merged_positions = [
            position for cube in cube_list for position in cube.positions
        ]
        template = cube_list[0]
        return cls(
            positions=merged_positions,
            color=template.color,
            width=template.width,
            visible=template.visible,
        )
@dataclass
class Lines(Primitive):
    """
    Create a set of lines. All lines will have the same color, thickness and
    (optional) dashes.

    :param positions: This is a list of lists corresponding to the positions
        of the lines. Each consecutive pair of vectors corresponds to the
        start and end position of a line segment (line segments do not have
        to be joined together).
    :param color: Line color as a hexadecimal string, e.g. #ff0000
    :param linewidth: The width of the line, defaults to 1
    :param scale: Optional, if provided will set a global scale for line
        dashes.
    :param dashSize: Optional, if provided will specify length of line
        dashes.
    :param gapSize: Optional, if provided will specify gap between line
        dashes.
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    linewidth: float | None = None
    scale: float | None = None
    dashSize: float | None = None
    gapSize: float | None = None
    type: str = field(default="lines", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: lines sharing it are candidates for merging."""
        parts = (
            "line",
            self.color,
            self.linewidth,
            self.dashSize,
            self.gapSize,
            self.reference,
        )
        return "_".join(f"{part}" for part in parts)

    @classmethod
    def merge(cls, line_list):
        """
        Merge several Lines into one holding all of their positions.

        Color, line width, scale, dash settings and visibility are taken from
        the first item; no other attribute is carried over.
        """
        merged_positions = [
            position for line in line_list for position in line.positions
        ]
        template = line_list[0]
        return cls(
            positions=merged_positions,
            color=template.color,
            linewidth=template.linewidth,
            scale=template.scale,
            dashSize=template.dashSize,
            gapSize=template.gapSize,
            visible=template.visible,
        )
@dataclass
class Surface:
    """
    Define a surface by its vertices. Please also provide normals if known.

    Opacity can be set to enable transparency, but note that the current
    Three.js renderer doesn't support nested transparent objects very well.
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    normals: list[list[float]] | None = None
    color: str | None = None
    opacity: float | None = None
    show_edges: bool = False
    type: str = field(default="surface", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def bounding_box(self) -> list[list[float]]:
        """Return a fixed, degenerate box at the origin."""
        # Not used in the calculation of the bounding box: the vertices are
        # deliberately ignored here.
        lower = [0, 0, 0]
        upper = [0, 0, 0]
        return [lower, upper]
@dataclass
class Convex:
    """
    Create a surface from the convex hull formed by a list of points. Note
    that at least four points must be specified. The current Three.js
    renderer uses the QuickHull algorithm.

    Opacity can be set to enable transparency, but note that the current
    Three.js renderer doesn't support nested transparent objects very well.
    """

    positions: list[list[float]]
    _animate: list[list[float]] | None = None
    color: str | None = None
    opacity: float | None = None
    type: str = field(default="convex", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def bounding_box(self) -> list[list[float]]:
        """Return a fixed, degenerate box at the origin."""
        # Not used in the calculation of the bounding box: the hull points
        # are deliberately ignored here.
        lower = [0, 0, 0]
        upper = [0, 0, 0]
        return [lower, upper]
@dataclass
class Arrows(Primitive):
    """
    Create a set of arrows. All arrows will have the same color, radius and
    head shape.

    :param positionPairs: This is a list of pairs of lists corresponding to
        the start and end position of the arrow.
    :param color: Arrow color as a hexadecimal string, e.g. #ff0000
    :param radius: The radius of the arrow shaft, defaults to 1.
    :param headLength: Optional length of the arrow head.
    :param headWidth: Optional width of the arrow head.
    :param visible: If False, will hide the object by default.
    :param reference: name to reference the primitive for callback
    :param clickable: if true, allows this primitive to be clicked and
        trigger an event
    """

    positionPairs: list[list[list[float]]]
    _animate: list[list[list[float]]] | None = None
    color: str | None = None
    radius: float | None = None
    headLength: float | None = None
    headWidth: float | None = None
    type: str = field(default="arrows", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None

    @property
    def key(self):
        """Grouping key: arrows sharing it are candidates for merging."""
        return f"arrow_{self.color}_{self.radius}_{self.headLength}_{self.headWidth}_{self.reference}"

    @classmethod
    def merge(cls, arrow_list):
        """
        Merge several Arrows into one holding all of their position pairs.

        Color, radius, head dimensions and visibility are taken from the
        first item.
        NOTE(review): ``reference`` is part of ``key`` but, like
        ``clickable``, ``_animate`` and ``_meta``, it is not carried over to
        the merged object -- confirm whether that is intended.
        """
        new_positionPairs = list(
            chain.from_iterable([arrow.positionPairs for arrow in arrow_list])
        )

        return cls(
            positionPairs=new_positionPairs,
            color=arrow_list[0].color,
            radius=arrow_list[0].radius,
            headLength=arrow_list[0].headLength,
            headWidth=arrow_list[0].headWidth,
            visible=arrow_list[0].visible,
        )

    @property
    def bounding_box(self) -> list[list[float]]:
        """
        Return ``[[min_x, min_y, min_z], [max_x, max_y, max_z]]`` over every
        start and end point in ``positionPairs``. Radius and head size are
        not taken into account.
        """
        # Flatten the (start, end) pairs into one stream of points, then
        # transpose into per-axis tuples.
        x, y, z = zip(*chain.from_iterable(self.positionPairs))
        # The upper corner must use max(); it previously repeated the minimum
        # corner, collapsing the box to a single point.
        return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
@dataclass
class Label:
    """
    Add a label to an object.

    Only ``label`` is required; every other field is optional.
    """

    label: str
    labelHover: str | None = None
    position: list[list[float]] | None = None
    type: str = field(default="labels", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None
@dataclass
class Bezier:
    """
    A tube shaped by Bézier control points.

    All fields are optional and default to ``None`` (``clickable`` to
    ``False``).
    """

    controlPoints: list[list[list[float]]] | None = None
    color: list[str] | None = None
    radius: list[float] | None = None
    type: str = field(default="bezier", init=False)  # private field
    visible: bool | None = None
    clickable: bool = False
    reference: str | None = None
    _meta: Any = None