from __future__ import annotations
import re
import warnings
from base64 import b64encode
from collections import OrderedDict
from itertools import chain, combinations_with_replacement
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Literal
import numpy as np
from dash import dash_table as dt
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from dash_mp_components import CrystalToolkitScene
from emmet.core.settings import EmmetSettings
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import DummySpecie
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.vasp.sets import MPRelaxSet
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.scene import Scene
from crystal_toolkit.helpers.layouts import H2, Field, dcc, html
from crystal_toolkit.settings import SETTINGS
# TODO: make dangling bonds "stubs"? (fixed length)

# Default option values, used as the parameter defaults of
# StructureMoleculeComponent (constructor, graph preprocessing and scene
# generation). Keys are option names; values are strings or booleans.
DEFAULTS: dict[str, str | bool] = dict(
    color_scheme="VESTA",
    bonding_strategy="CrystalNN",
    radius_strategy="uniform",
    draw_image_atoms=True,
    bonded_sites_outside_unit_cell=False,
    hide_incomplete_bonds=True,
    show_compass=True,
    unit_cell_choice="input",
    show_legend=True,
    show_settings=True,
    show_controls=True,
    show_expand_button=True,
    show_image_button=True,
    show_export_button=True,
    show_position_button=True,
)
class StructureMoleculeComponent(MPComponent):
    """
    A component to display pymatgen Structure, Molecule, StructureGraph
    and MoleculeGraph objects.
    """

    # Name -> class for every *direct* subclass of NearNeighbors, collected
    # once when this class is defined.
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    default_scene_settings = {
        "extractAxis": True,
        # For visual diff testing, we change the renderer
        # to SVG since this WebGL support is more difficult
        # in headless browsers / CI.
        "renderer": "svg" if SETTINGS.TEST_MODE else "webgl",
        "secondaryObjectView": False,
    }

    # what to show for the title_layout if structure/molecule not loaded
    default_title = "Crystal Toolkit"

    # human-readable label -> keyword arguments for Structure.to()
    # downloading Molecules has not yet been added
    download_options = {
        "Structure": {
            "CIF (Symmetrized)": {"fmt": "cif", "symprec": EmmetSettings().SYMPREC},
            "CIF": {"fmt": "cif"},
            "POSCAR": {"fmt": "poscar"},
            "JSON": {"fmt": "json"},
            "Prismatic": {"fmt": "prismatic"},
            "VASP Input Set (MPRelaxSet)": {},  # special
        }
    }

    def __init__(
        self,
        struct_or_mol: None
        | (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
        id: str | None = None,
        className: str = "box",
        scene_additions: Scene | None = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: dict | None = None,
        color_scheme: str = DEFAULTS["color_scheme"],
        color_scale: str | None = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"
        ],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: dict | None = None,
        group_by_site_property: str | None = None,
        show_legend: bool = DEFAULTS["show_legend"],
        show_settings: bool = DEFAULTS["show_settings"],
        show_controls: bool = DEFAULTS["show_controls"],
        show_expand_button: bool = DEFAULTS["show_expand_button"],
        show_image_button: bool = DEFAULTS["show_image_button"],
        show_export_button: bool = DEFAULTS["show_export_button"],
        show_position_button: bool = DEFAULTS["show_position_button"],
        **kwargs,
    ):
        """
        Create a StructureMoleculeComponent from a structure or molecule.

        :param struct_or_mol: input structure or molecule
        :param id: canonical id
        :param className: CSS class name applied to the CrystalToolkitScene
        :param scene_additions: extra geometric elements to add to the 3D scene
        :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class
        :param bonding_strategy_kwargs: options for the bonding strategy
        :param color_scheme: color scheme, see Legend class
        :param color_scale: color scale, see Legend class
        :param radius_strategy: radius strategy, see Legend class
        :param unit_cell_choice: one of "input", "primitive", "conventional",
            "reduced_niggli" or "reduced_lll"
        :param draw_image_atoms: whether to draw repeats of atoms on periodic images
        :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell
        :param hide_incomplete_bonds: whether to hide or show incomplete bonds
        :param show_compass: whether to hide or show the compass
        :param scene_settings: scene settings (lighting etc.) to pass to CrystalToolkitScene
        :param group_by_site_property: a site property used for grouping of atoms for
            mouseover/interaction, e.g. Wyckoff label
        :param show_legend: show or hide legend panel within the scene
        :param show_settings: show or hide the settings panel (unit cell, bonding,
            colors, radii, draw options) within the scene
        :param show_controls: show or hide scene control bar
        :param show_expand_button: show or hide the full screen button within the scene control bar
        :param show_image_button: show or hide the image download button within the scene control bar
        :param show_export_button: show or hide the file export button within the scene control bar
        :param show_position_button: show or hide the revert position button within the scene control bar
        :param kwargs: extra keyword arguments to pass to MPComponent
        """
        super().__init__(id=id, default_data=struct_or_mol, **kwargs)

        self.className = className
        self.show_legend = show_legend
        self.show_settings = show_settings
        self.show_controls = show_controls
        self.show_expand_button = show_expand_button
        self.show_image_button = show_image_button
        self.show_export_button = show_export_button
        self.show_position_button = show_position_button

        # copy so per-instance updates never leak into the class-level defaults
        self.initial_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)
        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        self.create_store(
            "graph_generation_options",
            initial_data={
                "bonding_strategy": bonding_strategy,
                "bonding_strategy_kwargs": bonding_strategy_kwargs,
                "unit_cell_choice": unit_cell_choice,
            },
        )
        self.create_store(
            "display_options",
            initial_data={
                "color_scheme": color_scheme,
                "color_scale": color_scale,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": draw_image_atoms,
                "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
                "hide_incomplete_bonds": hide_incomplete_bonds,
                "show_compass": show_compass,
                "group_by_site_property": group_by_site_property,
            },
        )

        if scene_additions:
            initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions
            ).to_json()
        else:
            initial_scene_additions = None
        self.create_store("scene_additions", initial_data=initial_scene_additions)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            struct_or_mol = self._preprocess_structure(
                struct_or_mol, unit_cell_choice=unit_cell_choice
            )
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )
            if hasattr(struct_or_mol, "lattice"):
                self._lattice = struct_or_mol.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a CrystalToolkitScene component, not a dcc.Store
        self._initial_data["scene"] = scene

        # hide axes inset for molecules
        if isinstance(struct_or_mol, (Molecule, MoleculeGraph)):
            self.scene_kwargs = {"axisView": "HIDDEN"}
        else:
            self.scene_kwargs = {}

    def generate_callbacks(self, app, cache):
        """Register all Dash callbacks for this component on ``app``.

        :param app: the Dash app
        :param cache: a cache providing ``memoize()``, used for expensive
            server-side callbacks
        """
        # a lot of the verbosity in this callback is to support custom bonding
        # this is not the format CutOffDictNN expects (since that is not JSON
        # serializable), so we store as a list of tuples instead
        # TODO: make CutOffDictNN args JSON serializable
        app.clientside_callback(
            """
            function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) {
                const bonding_strategy_kwargs = {}
                if (bonding_strategy === 'CutOffDictNN') {
                    const cut_off_dict = []
                    custom_cutoffs_rows.forEach(function(row) {
                        cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])])
                    })
                    bonding_strategy_kwargs.cut_off_dict = cut_off_dict
                }
                return {
                    bonding_strategy: bonding_strategy,
                    bonding_strategy_kwargs: bonding_strategy_kwargs,
                    unit_cell_choice: unit_cell_choice
                }
            }
            """,
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
            ],
        )

        # map the hide/show checklist onto a {option: visible} dict for the scene
        app.clientside_callback(
            """
            function (values, options) {
                const visibility = {}
                options.forEach(function (opt) {
                    visibility[opt.value] = Boolean(values.includes(opt.value))
                })
                return visibility
            }
            """,
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )

        # merge control values into the existing display options, keeping any
        # keys (e.g. color_scale, show_compass) that have no UI control
        app.clientside_callback(
            """
            function (colorScheme, radiusStrategy, drawOptions, displayOptions) {
                const newDisplayOptions = {...displayOptions}
                newDisplayOptions.color_scheme = colorScheme
                newDisplayOptions.radius_strategy = radiusStrategy
                newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms')
                newDisplayOptions.bonded_sites_outside_unit_cell = drawOptions.includes(
                    'bonded_sites_outside_unit_cell'
                )
                newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds')
                return newDisplayOptions
            }
            """,
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id(), "data"),
            ],
            [State(self.id("graph"), "data")],
        )
        @cache.memoize()
        def update_graph(graph_generation_options, struct_or_mol, current_graph):
            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            current_graph = self.from_data(current_graph)

            bonding_strategy_kwargs = graph_generation_options[
                "bonding_strategy_kwargs"
            ]

            # TODO: add additional check here?
            unit_cell_choice = graph_generation_options["unit_cell_choice"]
            struct_or_mol = self._preprocess_structure(struct_or_mol, unit_cell_choice)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

            # Skip the update when nothing changed. Compare the underlying
            # Structure/Molecule generically: MoleculeGraph has no `.structure`.
            if (
                current_graph
                and type(current_graph) is type(graph)
                and self._get_struct_or_mol(graph)
                == self._get_struct_or_mol(current_graph)
                and graph == current_graph
            ):
                raise PreventUpdate

            return graph

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_scene(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, _ = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions
            )
            return scene

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_legend_and_colors(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            _, legend = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions
            )
            return legend

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_color_options(legend_data):
            # TODO: make client-side
            color_options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Accessible", "value": "accessible"},
            ]

            if not legend_data:
                return color_options

            # site properties that can be used as a color scheme
            color_options.extend(
                {"label": f"Site property: {option}", "value": option}
                for option in legend_data.get("available_color_schemes", [])
            )

            return color_options

        @app.callback(
            Output(self.id("download-image"), "data"),
            Input(self.id("scene"), "imageDataTimestamp"),
            [
                State(self.id("scene"), "imageData"),
                State(self.id(), "data"),
            ],
        )
        def download_image(image_data_timestamp, image_data, data):
            if not image_data_timestamp:
                raise PreventUpdate

            # unwrap StructureGraph/MoleculeGraph so composition and space
            # group are read from the underlying Structure/Molecule
            struct_or_mol = self._get_struct_or_mol(self.from_data(data))
            formula = struct_or_mol.composition.reduced_formula
            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""

            request_filename = f"{formula}-{spgrp}-crystal-toolkit.png"

            return {
                # strip the data-URL header, leaving only the base64 payload
                "content": image_data[len("data:image/png;base64,") :],
                "filename": request_filename,
                "base64": True,
                "type": "image/png",
            }

        @app.callback(
            Output(self.id("download-structure"), "data"),
            Input(self.id("scene"), "fileTimestamp"),
            [
                State(self.id("scene"), "fileType"),
                State(self.id(), "data"),
            ],
        )
        def download_structure(file_timestamp, download_option, data):
            if not file_timestamp:
                raise PreventUpdate

            structure = self._get_struct_or_mol(self.from_data(data))
            file_prefix = structure.composition.reduced_formula

            if "VASP" not in download_option:
                options = self.download_options["Structure"][download_option]
                extension = options["fmt"]

                try:
                    contents = structure.to(**options)
                except Exception as exc:
                    # don't fail silently, tell user what went wrong
                    # (the exception must be converted to text before encoding)
                    contents = str(exc)

                encoded = b64encode(contents.encode("utf-8")).decode("ascii")

                download_data = {
                    "content": encoded,
                    "base64": True,
                    "type": "text/plain",
                    "filename": f"{file_prefix}.{extension}",
                }

            else:
                if "Relax" in download_option:
                    vis = MPRelaxSet(structure)
                    expected_filename = "MPRelaxSet.zip"
                else:
                    raise ValueError("No other VASP input sets currently supported.")

                with TemporaryDirectory() as tmpdir:
                    vis.write_input(tmpdir, potcar_spec=True, zip_output=True)
                    path = Path(tmpdir) / expected_filename
                    zip_b64 = b64encode(path.read_bytes()).decode("ascii")

                download_data = {
                    "content": zip_b64,
                    "base64": True,
                    "type": "application/zip",
                    "filename": f"{file_prefix} {expected_filename}",
                }

            return download_data

        @app.callback(
            Output(self.id("title_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_title(legend):
            if not legend:
                raise PreventUpdate
            legend = self.from_data(legend)
            return self._make_title(legend)

        @app.callback(
            Output(self.id("legend_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_legend(legend):
            if not legend:
                raise PreventUpdate
            legend = self.from_data(legend)
            return self._make_legend(legend)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [
                State(self.id("graph"), "data"),
                State(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
        )
        @cache.memoize()
        def update_custom_bond_options(bonding_algorithm, graph, current_style):
            if not graph:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}
            if style == current_style:
                # no need to update rows if we're not showing them
                raise PreventUpdate

            graph = self.from_data(graph)
            rows = self._make_bonding_algorithm_custom_cuffoff_data(graph)

            return rows, style

    def _make_legend(self, legend):
        """Build the legend row: one colored pill per species, ordered by
        position in the reduced formula."""
        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            # (expects "#rrggbb"; weights are the ITU-R BT.601 luma coefficients)
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # legend["colors"] maps color -> species name; order entries by where
        # the name appears in the formula (names not found sort first, at -1)
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(), key=lambda x: formula.find(x[1]))
        )

        legend_elements = [
            html.Span(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                className="button is-static is-rounded",
                style={"backgroundColor": color},
            )
            for color, name in legend_colors.items()
        ]

        return html.Div(
            legend_elements,
            id=self.id("legend"),
            style={"display": "flex"},
            className="buttons",
        )

    def _make_title(self, legend):
        """Build an H2 title showing the reduced formula (digits as subscripts),
        or the default title when no composition is available."""
        if not legend or not legend.get("composition"):
            return H2(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):
            try:
                composition = Composition.from_dict(composition)

                # strip DummySpecie if present (TODO: should be method in pymatgen)
                composition = Composition(
                    {
                        el: amt
                        for el, amt in composition.items()
                        if not isinstance(el, DummySpecie)
                    }
                )
                composition = composition.get_reduced_composition_and_factor()[0]
                formula = composition.reduced_formula
                # split e.g. "Fe2O3" into ["Fe", "2", "O", "3"]
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)
                formula_components = [
                    html.Sub(part.strip())
                    if part.isnumeric()
                    else html.Span(part.strip())
                    for part in formula_parts
                ]
            except Exception:
                formula_components = list(map(str, composition))
        else:
            # non-dict compositions previously left formula_components unbound
            formula_components = [str(composition)]

        return H2(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """Return DataTable rows with one zero-length cutoff per pair of species.

        Species are sorted so the row order is stable across processes.
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]
        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        # can't use type_of_specie because it doesn't work with disordered structures
        species = sorted(
            {str(sp) for comp in struct_or_mol.species_and_occu for sp in comp}
        )
        return [
            {"A": a, "B": b, "A—B": 0}
            for a, b in combinations_with_replacement(species, 2)
        ]

    @property
    def _sub_layouts(self):
        """Dash layouts for the scene, settings panel, title and legend."""
        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        # human-readable label -> NearNeighbors subclass name
        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value=self.initial_data["graph_generation_options"]["bonding_strategy"],
            clearable=False,
            id=self.id("bonding_algorithm"),
            persistence=SETTINGS.PERSISTENCE,
            persistence_type=SETTINGS.PERSISTENCE_TYPE,
        )

        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    data=self._make_bonding_algorithm_custom_cuffoff_data(
                        self.initial_data.get("default")
                    ),
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        if self.show_settings:
            options_layout = Field(
                [
                    # TODO: hide if molecule
                    html.Label("Change unit cell:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {"label": "Input cell", "value": "input"},
                                {"label": "Primitive cell", "value": "primitive"},
                                {"label": "Conventional cell", "value": "conventional"},
                                {
                                    "label": "Reduced cell (Niggli)",
                                    "value": "reduced_niggli",
                                },
                                {"label": "Reduced cell (LLL)", "value": "reduced_lll"},
                            ],
                            # honour the constructor's unit_cell_choice; a
                            # hard-coded "input" would override it on page load
                            value=self.initial_data["graph_generation_options"][
                                "unit_cell_choice"
                            ],
                            clearable=False,
                            id=self.id("unit-cell-choice"),
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                        ),
                        className="mpc-control",
                    ),
                    html.Div(
                        [
                            html.Label(
                                "Change bonding algorithm: ", className="mpc-label"
                            ),
                            bonding_algorithm,
                            bonding_algorithm_custom_cutoffs,
                        ]
                    ),
                    html.Label("Change color scheme:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {"label": "VESTA", "value": "VESTA"},
                                {"label": "Jmol", "value": "Jmol"},
                                {"label": "Accessible", "value": "accessible"},
                            ],
                            value=self.initial_data["display_options"]["color_scheme"],
                            clearable=False,
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            id=self.id("color-scheme"),
                        ),
                        className="mpc-control",
                    ),
                    html.Label("Change atomic radii:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {
                                    "label": "Ionic",
                                    "value": "specified_or_average_ionic",
                                },
                                {"label": "Covalent", "value": "covalent"},
                                {"label": "Van der Waals", "value": "van_der_waals"},
                                {
                                    "label": f"Uniform ({Legend.uniform_radius}Å)",
                                    "value": "uniform",
                                },
                            ],
                            value=self.initial_data["display_options"][
                                "radius_strategy"
                            ],
                            clearable=False,
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            id=self.id("radius_strategy"),
                        ),
                        className="mpc-control",
                    ),
                    html.Label("Draw options:", className="mpc-label"),
                    html.Div(
                        [
                            dcc.Checklist(
                                options=[
                                    {
                                        "label": "Draw repeats of atoms on periodic boundaries",
                                        "value": "draw_image_atoms",
                                    },
                                    {
                                        "label": "Draw atoms outside unit cell bonded to "
                                        "atoms within unit cell",
                                        "value": "bonded_sites_outside_unit_cell",
                                    },
                                    {
                                        "label": "Hide bonds where destination atoms are not shown",
                                        "value": "hide_incomplete_bonds",
                                    },
                                ],
                                value=[
                                    opt
                                    for opt in (
                                        "draw_image_atoms",
                                        "bonded_sites_outside_unit_cell",
                                        "hide_incomplete_bonds",
                                    )
                                    if self.initial_data["display_options"][opt]
                                ],
                                labelStyle={"display": "block"},
                                inputClassName="mpc-radio",
                                id=self.id("draw_options"),
                                persistence=SETTINGS.PERSISTENCE,
                                persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            )
                        ]
                    ),
                    html.Label("Hide/show:", className="mpc-label"),
                    html.Div(
                        [
                            dcc.Checklist(
                                options=[
                                    {"label": "Atoms", "value": "atoms"},
                                    {"label": "Bonds", "value": "bonds"},
                                    {"label": "Unit cell", "value": "unit_cell"},
                                    {"label": "Polyhedra", "value": "polyhedra"},
                                    {"label": "Axes", "value": "axes"},
                                ],
                                value=["atoms", "bonds", "unit_cell", "polyhedra"],
                                labelStyle={"display": "block"},
                                inputClassName="mpc-radio",
                                id=self.id("hide-show"),
                                persistence=SETTINGS.PERSISTENCE,
                                persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            )
                        ],
                        className="mpc-control",
                    ),
                ]
            )
        else:
            options_layout = None

        if self.show_legend:
            legend_layout = html.Div(
                self._make_legend(self._initial_data["legend_data"]),
                id=self.id("legend_container"),
            )
        else:
            legend_layout = None

        struct_layout = html.Div(
            [
                CrystalToolkitScene(
                    [options_layout, legend_layout],
                    id=self.id("scene"),
                    className=self.className,
                    data=self.initial_data["scene"],
                    settings=self.initial_scene_settings,
                    sceneSize="100%",
                    fileOptions=list(self.download_options["Structure"]),
                    showControls=self.show_controls,
                    showExpandButton=self.show_expand_button,
                    showImageButton=self.show_image_button,
                    showExportButton=self.show_export_button,
                    showPositionButton=self.show_position_button,
                    **self.scene_kwargs,
                ),
                dcc.Download(id=self.id("download-image")),
                dcc.Download(id=self.id("download-structure")),
            ]
        )

        return {
            "struct": struct_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "500px") -> html.Div:
        """
        :param size: a CSS dimension specifying width/height of Div
        :return: A html.Div containing the 3D structure or molecule
        """
        return html.Div(
            self._sub_layouts["struct"], style={"width": size, "height": size}
        )

    @staticmethod
    def _preprocess_structure(
        struct_or_mol: Structure | StructureGraph | Molecule | MoleculeGraph,
        unit_cell_choice: Literal[
            "input", "primitive", "conventional", "reduced_niggli", "reduced_lll"
        ] = "input",
    ):
        """Transform a Structure into the requested unit cell.

        Non-Structure inputs, "input", and unrecognised choices are returned
        unchanged.
        """
        if isinstance(struct_or_mol, Structure):
            if unit_cell_choice == "primitive":
                struct_or_mol = struct_or_mol.get_primitive_structure()
            elif unit_cell_choice == "conventional":
                sga = SpacegroupAnalyzer(struct_or_mol)
                struct_or_mol = sga.get_conventional_standard_structure()
            elif unit_cell_choice == "reduced_niggli":
                struct_or_mol = struct_or_mol.get_reduced_structure(
                    reduction_algo="niggli"
                )
            elif unit_cell_choice == "reduced_lll":
                struct_or_mol = struct_or_mol.get_reduced_structure(
                    reduction_algo="LLL"
                )

        return struct_or_mol

    @staticmethod
    def _preprocess_input_to_graph(
        input: Structure | StructureGraph | Molecule | MoleculeGraph,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: dict | None = None,
    ) -> StructureGraph | MoleculeGraph:
        """Return a bonded graph for ``input``.

        Graphs are returned as-is. Structures and molecules are bonded using
        the named NearNeighbors strategy. If bonding fails, an empty graph
        (no bonds) is returned.

        :param input: structure, molecule, or an existing graph
        :param bonding_strategy: name of a NearNeighbors subclass
        :param bonding_strategy_kwargs: kwargs for the strategy; for
            CutOffDictNN, ``cut_off_dict`` may be a list of ``[A, B, distance]``
            triples (JSON form) or a ``{(A, B): distance}`` dict. Never mutated.
        :raises ValueError: if ``bonding_strategy`` is not a known subclass
        """
        if isinstance(input, Structure):
            # ensure fractional coordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        if bonding_strategy not in StructureMoleculeComponent.available_bonding_strategies:
            valid_subclasses = ", ".join(
                StructureMoleculeComponent.available_bonding_strategies
            )
            raise ValueError(
                "Bonding strategy not supported. Please supply a name of a NearNeighbor "
                f"subclass, choose from: {valid_subclasses}"
            )

        # work on a copy: the caller's dict may live in a JSON store, and the
        # conversion below introduces (non-JSON-serializable) tuple keys
        bonding_strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN":
            cut_off_dict = bonding_strategy_kwargs.get("cut_off_dict")
            if cut_off_dict is not None and not isinstance(cut_off_dict, dict):
                # TODO: remove this hack by making args properly JSON serializable
                bonding_strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_off_dict
                }

        strategy = StructureMoleculeComponent.available_bonding_strategies[
            bonding_strategy
        ](**bonding_strategy_kwargs)

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if isinstance(input, Structure):
                    graph = StructureGraph.with_local_env_strategy(input, strategy)
                else:
                    graph = MoleculeGraph.with_local_env_strategy(
                        input, strategy, reorder=False
                    )
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _get_struct_or_mol(
        graph: StructureGraph | MoleculeGraph | Structure | Molecule,
    ) -> Structure | Molecule:
        """Return the Structure/Molecule underlying a graph (or the input itself).

        :raises ValueError: for any other type
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        if isinstance(graph, (Structure, Molecule)):
            return graph
        raise ValueError(
            f"Cannot get a Structure or Molecule from {type(graph).__name__}"
        )

    @staticmethod
    def get_scene_and_legend(
        graph: StructureGraph | MoleculeGraph | None,
        color_scheme=DEFAULTS["color_scheme"],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS["bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
        group_by_site_property=None,
    ) -> tuple[dict, dict]:
        """Build the JSON scene and the legend data for ``graph``.

        When ``graph`` is None, an empty scene is returned together with an
        empty legend. Any ``scene_additions`` (already-serialized scene JSON)
        are appended to the scene contents in both cases.

        :return: (scene JSON, legend data)
        """
        scene = Scene(name="StructureMoleculeComponentScene")
        legend_data: dict = {}

        if graph is not None:
            struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

            # TODO: add radius_scale
            legend = Legend(
                struct_or_mol,
                color_scheme=color_scheme,
                radius_scheme=radius_strategy,
                cmap_range=color_scale,
            )

            if isinstance(graph, StructureGraph):
                scene = graph.get_scene(
                    draw_image_atoms=draw_image_atoms,
                    bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
                    hide_incomplete_edges=hide_incomplete_bonds,
                    explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
                    group_by_site_property=group_by_site_property,
                    legend=legend,
                )
            elif isinstance(graph, MoleculeGraph):
                scene = graph.get_scene(legend=legend)

            scene.name = "StructureMoleculeComponentScene"

            if hasattr(struct_or_mol, "lattice"):
                axes = struct_or_mol.lattice._axes_from_lattice()
                axes.visible = show_compass
                scene.contents.append(axes)

            # read after the scene is built, as in the original ordering
            legend_data = legend.get_legend()

        # always return serialized JSON, so callers (and CrystalToolkitScene)
        # receive the same type whether or not a graph was supplied
        scene_json = scene.to_json()
        if scene_additions:
            # TODO: this might be cleaner if we had a Scene.from_json() method
            scene_json["contents"].append(scene_additions)

        return scene_json, legend_data

    def title_layout(self):
        """
        :return: A layout including the composition of the structure/molecule as a title.
        """
        return self._sub_layouts["title"]