# Source code for crystal_toolkit.components.structure

from __future__ import annotations

import re
import warnings
from base64 import b64encode
from collections import OrderedDict
from itertools import chain, combinations_with_replacement
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Literal

import numpy as np
from dash import dash_table as dt
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from dash_mp_components import CrystalToolkitScene
from emmet.core.settings import EmmetSettings
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import DummySpecie
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.vasp.sets import MPRelaxSet
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.scene import Scene
from crystal_toolkit.helpers.layouts import H2, Field, dcc, html
from crystal_toolkit.settings import SETTINGS

# TODO: make dangling bonds "stubs"? (fixed length)

# Fallback values for the keyword arguments of StructureMoleculeComponent and
# its helpers; string entries select a named scheme/strategy, boolean entries
# toggle a display feature or a control-bar button.
DEFAULTS: dict[str, str | bool] = dict(
    color_scheme="VESTA",
    bonding_strategy="CrystalNN",
    radius_strategy="uniform",
    draw_image_atoms=True,
    bonded_sites_outside_unit_cell=False,
    hide_incomplete_bonds=True,
    show_compass=True,
    unit_cell_choice="input",
    show_legend=True,
    show_settings=True,
    show_controls=True,
    show_expand_button=True,
    show_image_button=True,
    show_export_button=True,
    show_position_button=True,
)


class StructureMoleculeComponent(MPComponent):
    """
    A component to display pymatgen Structure, Molecule, StructureGraph
    and MoleculeGraph objects.
    """

    # maps the class name of every direct NearNeighbors subclass to the class
    available_bonding_strategies = {
        subclass.__name__: subclass for subclass in NearNeighbors.__subclasses__()
    }

    default_scene_settings = {
        "extractAxis": True,
        # For visual diff testing, we change the renderer
        # to SVG since this WebGL support is more difficult
        # in headless browsers / CI.
        "renderer": "svg" if SETTINGS.TEST_MODE else "webgl",
        "secondaryObjectView": False,
    }

    # what to show for the title_layout if structure/molecule not loaded
    default_title = "Crystal Toolkit"

    # human-readable label to file extension
    # downloading Molecules has not yet been added
    download_options = {
        "Structure": {
            "CIF (Symmetrized)": {"fmt": "cif", "symprec": EmmetSettings().SYMPREC},
            "CIF": {"fmt": "cif"},
            "POSCAR": {"fmt": "poscar"},
            "JSON": {"fmt": "json"},
            "Prismatic": {"fmt": "prismatic"},
            "VASP Input Set (MPRelaxSet)": {},  # special
        }
    }

    def __init__(
        self,
        struct_or_mol: Structure | StructureGraph | Molecule | MoleculeGraph | None = None,
        id: str | None = None,
        className: str = "box",
        scene_additions: Scene | None = None,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: dict | None = None,
        color_scheme: str = DEFAULTS["color_scheme"],
        color_scale: str | None = None,
        radius_strategy: str = DEFAULTS["radius_strategy"],
        unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
        draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell: bool = DEFAULTS[
            "bonded_sites_outside_unit_cell"
        ],
        hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
        show_compass: bool = DEFAULTS["show_compass"],
        scene_settings: dict | None = None,
        group_by_site_property: str | None = None,
        show_legend: bool = DEFAULTS["show_legend"],
        show_settings: bool = DEFAULTS["show_settings"],
        show_controls: bool = DEFAULTS["show_controls"],
        show_expand_button: bool = DEFAULTS["show_expand_button"],
        show_image_button: bool = DEFAULTS["show_image_button"],
        show_export_button: bool = DEFAULTS["show_export_button"],
        show_position_button: bool = DEFAULTS["show_position_button"],
        **kwargs,
    ):
        """
        Create a StructureMoleculeComponent from a structure or molecule.

        :param struct_or_mol: input structure or molecule
        :param id: canonical id
        :param className: CSS class name passed to the CrystalToolkitScene
        :param scene_additions: extra geometric elements to add to the 3D scene
        :param bonding_strategy: bonding strategy from pymatgen NearNeighbors class
        :param bonding_strategy_kwargs: options for the bonding strategy
        :param color_scheme: color scheme, see Legend class
        :param color_scale: color scale, see Legend class
        :param radius_strategy: radius strategy, see Legend class
        :param unit_cell_choice: which unit cell to display, one of "input",
            "primitive", "conventional", "reduced_niggli" or "reduced_lll"
        :param draw_image_atoms: whether to draw repeats of atoms on periodic images
        :param bonded_sites_outside_unit_cell: whether to draw sites bonded outside the unit cell
        :param hide_incomplete_bonds: whether to hide or show incomplete bonds
        :param show_compass: whether to hide or show the compass
        :param scene_settings: scene settings (lighting etc.) to pass to CrystalToolkitScene
        :param group_by_site_property: a site property used for grouping of atoms for
            mouseover/interaction, e.g. Wyckoff label
        :param show_legend: show or hide legend panel within the scene
        :param show_settings: show or hide the settings (options) panel within the scene
        :param show_controls: show or hide scene control bar
        :param show_expand_button: show or hide the full screen button within the scene control bar
        :param show_image_button: show or hide the image download button within the scene control bar
        :param show_export_button: show or hide the file export button within the scene control bar
        :param show_position_button: show or hide the revert position button within the scene control bar
        :param kwargs: extra keyword arguments to pass to MPComponent
        """
        super().__init__(id=id, default_data=struct_or_mol, **kwargs)

        self.className = className
        self.show_legend = show_legend
        self.show_settings = show_settings
        self.show_controls = show_controls
        self.show_expand_button = show_expand_button
        self.show_image_button = show_image_button
        self.show_export_button = show_export_button
        self.show_position_button = show_position_button

        # user-supplied scene settings override the class-level defaults
        self.initial_scene_settings = self.default_scene_settings.copy()
        if scene_settings:
            self.initial_scene_settings.update(scene_settings)

        self.create_store("scene_settings", initial_data=self.initial_scene_settings)

        # unit cell choice and bonding algorithms need to come from a settings
        # object (in a dcc.Store) guaranteed to be present in layout, rather
        # than from the controls themselves -- since these are optional and
        # may not be present in the layout
        self.create_store(
            "graph_generation_options",
            initial_data={
                "bonding_strategy": bonding_strategy,
                "bonding_strategy_kwargs": bonding_strategy_kwargs,
                "unit_cell_choice": unit_cell_choice,
            },
        )

        self.create_store(
            "display_options",
            initial_data={
                "color_scheme": color_scheme,
                "color_scale": color_scale,
                "radius_strategy": radius_strategy,
                "draw_image_atoms": draw_image_atoms,
                "bonded_sites_outside_unit_cell": bonded_sites_outside_unit_cell,
                "hide_incomplete_bonds": hide_incomplete_bonds,
                "show_compass": show_compass,
                "group_by_site_property": group_by_site_property,
            },
        )

        if scene_additions:
            initial_scene_additions = Scene(
                name="scene_additions", contents=scene_additions
            ).to_json()
        else:
            initial_scene_additions = None
        self.create_store("scene_additions", initial_data=initial_scene_additions)

        if struct_or_mol:
            # graph is cached explicitly, this isn't necessary but is an
            # optimization so that graph is only re-generated if bonding
            # algorithm changes
            struct_or_mol = self._preprocess_structure(
                struct_or_mol, unit_cell_choice=unit_cell_choice
            )
            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=bonding_strategy,
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )
            scene, legend = self.get_scene_and_legend(
                graph,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )
            if hasattr(struct_or_mol, "lattice"):
                self._lattice = struct_or_mol.lattice
        else:
            # component could be initialized without a structure, in which case
            # an empty scene should be displayed
            graph = None
            scene, legend = self.get_scene_and_legend(
                None,
                scene_additions=self.initial_data["scene_additions"],
                **self.initial_data["display_options"],
            )

        self.create_store("legend_data", initial_data=legend)
        self.create_store("graph", initial_data=graph)

        # this is used by a CrystalToolkitScene component, not a dcc.Store
        self._initial_data["scene"] = scene

        # hide axes inset for molecules
        if isinstance(struct_or_mol, (Molecule, MoleculeGraph)):
            self.scene_kwargs = {"axisView": "HIDDEN"}
        else:
            self.scene_kwargs = {}

    def generate_callbacks(self, app, cache):
        """
        Register the client-side and server-side Dash callbacks of this component.

        :param app: the Dash app; its ``callback`` and ``clientside_callback``
            methods are used for registration
        :param cache: cache object whose ``memoize()`` decorator wraps the
            server-side callbacks
        """
        # a lot of the verbosity in this callback is to support custom bonding
        # this is not the format CutOffDictNN expects (since that is not JSON
        # serializable), so we store as a list of tuples instead
        # TODO: make CutOffDictNN args JSON serializable
        app.clientside_callback(
            """
            function (bonding_strategy, custom_cutoffs_rows, unit_cell_choice) {

                const bonding_strategy_kwargs = {}
                if (bonding_strategy === 'CutOffDictNN') {
                    const cut_off_dict = []
                    custom_cutoffs_rows.forEach(function(row) {
                        cut_off_dict.push([row['A'], row['B'], parseFloat(row['A—B'])])
                    })
                    bonding_strategy_kwargs.cut_off_dict = cut_off_dict
                }

                return {
                    bonding_strategy: bonding_strategy,
                    bonding_strategy_kwargs: bonding_strategy_kwargs,
                    unit_cell_choice: unit_cell_choice
                }
            }
            """,
            Output(self.id("graph_generation_options"), "data"),
            [
                Input(self.id("bonding_algorithm"), "value"),
                Input(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Input(self.id("unit-cell-choice"), "value"),
            ],
        )

        app.clientside_callback(
            """
            function (values, options) {
                const visibility = {}
                options.forEach(function (opt) {
                    visibility[opt.value] = Boolean(values.includes(opt.value))
                })
                return visibility
            }
            """,
            Output(self.id("scene"), "toggleVisibility"),
            [Input(self.id("hide-show"), "value")],
            [State(self.id("hide-show"), "options")],
        )

        app.clientside_callback(
            """
            function (colorScheme, radiusStrategy, drawOptions, displayOptions) {

                const newDisplayOptions = {...displayOptions}
                newDisplayOptions.color_scheme = colorScheme
                newDisplayOptions.radius_strategy = radiusStrategy
                newDisplayOptions.draw_image_atoms = drawOptions.includes('draw_image_atoms')
                newDisplayOptions.bonded_sites_outside_unit_cell = drawOptions.includes(
                    'bonded_sites_outside_unit_cell'
                )
                newDisplayOptions.hide_incomplete_bonds = drawOptions.includes('hide_incomplete_bonds')

                return newDisplayOptions
            }
            """,
            Output(self.id("display_options"), "data"),
            [
                Input(self.id("color-scheme"), "value"),
                Input(self.id("radius_strategy"), "value"),
                Input(self.id("draw_options"), "value"),
            ],
            [State(self.id("display_options"), "data")],
        )

        @app.callback(
            Output(self.id("graph"), "data"),
            [
                Input(self.id("graph_generation_options"), "data"),
                Input(self.id(), "data"),
            ],
            [State(self.id("graph"), "data")],
        )
        @cache.memoize()
        def update_graph(graph_generation_options, struct_or_mol, current_graph):
            if not struct_or_mol:
                raise PreventUpdate

            struct_or_mol = self.from_data(struct_or_mol)
            current_graph = self.from_data(current_graph)

            bonding_strategy_kwargs = graph_generation_options[
                "bonding_strategy_kwargs"
            ]

            # TODO: add additional check here?
            unit_cell_choice = graph_generation_options["unit_cell_choice"]
            struct_or_mol = self._preprocess_structure(struct_or_mol, unit_cell_choice)

            graph = self._preprocess_input_to_graph(
                struct_or_mol,
                bonding_strategy=graph_generation_options["bonding_strategy"],
                bonding_strategy_kwargs=bonding_strategy_kwargs,
            )

            # compare via _get_struct_or_mol so that MoleculeGraph (which holds
            # its data under .molecule rather than .structure) is handled too
            if (
                current_graph
                and self._get_struct_or_mol(graph)
                == self._get_struct_or_mol(current_graph)
                and graph == current_graph
            ):
                raise PreventUpdate

            return graph

        @app.callback(
            Output(self.id("scene"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_scene(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            scene, _ = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions
            )
            return scene

        @app.callback(
            Output(self.id("legend_data"), "data"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
                Input(self.id("scene_additions"), "data"),
            ],
        )
        @cache.memoize()
        def update_legend_and_colors(graph, display_options, scene_additions):
            if not graph or not display_options:
                raise PreventUpdate
            display_options = self.from_data(display_options)
            graph = self.from_data(graph)
            _, legend = self.get_scene_and_legend(
                graph, **display_options, scene_additions=scene_additions
            )
            return legend

        @app.callback(
            Output(self.id("color-scheme"), "options"),
            [Input(self.id("legend_data"), "data")],
        )
        def update_color_options(legend_data):
            # TODO: make client-side
            color_options = [
                {"label": "Jmol", "value": "Jmol"},
                {"label": "VESTA", "value": "VESTA"},
                {"label": "Accessible", "value": "accessible"},
            ]

            if not legend_data:
                return color_options

            # site properties reported by the legend become extra color schemes
            color_options.extend(
                {"label": f"Site property: {option}", "value": option}
                for option in legend_data["available_color_schemes"]
            )

            return color_options

        @app.callback(
            Output(self.id("download-image"), "data"),
            Input(self.id("scene"), "imageDataTimestamp"),
            [
                State(self.id("scene"), "imageData"),
                State(self.id(), "data"),
            ],
        )
        def download_image(image_data_timestamp, image_data, data):
            if not image_data_timestamp:
                raise PreventUpdate

            struct_or_mol = self.from_data(data)
            if isinstance(struct_or_mol, StructureGraph):
                formula = struct_or_mol.structure.composition.reduced_formula
            elif isinstance(struct_or_mol, MoleculeGraph):
                formula = struct_or_mol.molecule.composition.reduced_formula
            else:
                formula = struct_or_mol.composition.reduced_formula

            if hasattr(struct_or_mol, "get_space_group_info"):
                spgrp = struct_or_mol.get_space_group_info()[0]
            else:
                spgrp = ""

            request_filename = f"{formula}-{spgrp}-crystal-toolkit.png"

            return {
                # strip the data-URL prefix, leaving only the base64 payload
                "content": image_data[len("data:image/png;base64,") :],
                "filename": request_filename,
                "base64": True,
                "type": "image/png",
            }

        @app.callback(
            Output(self.id("download-structure"), "data"),
            Input(self.id("scene"), "fileTimestamp"),
            [
                State(self.id("scene"), "fileType"),
                State(self.id(), "data"),
            ],
        )
        def download_structure(file_timestamp, download_option, data):
            if not file_timestamp:
                raise PreventUpdate

            structure = self.from_data(data)
            if isinstance(structure, StructureGraph):
                structure = structure.structure

            file_prefix = structure.composition.reduced_formula

            if "VASP" not in download_option:
                options = self.download_options["Structure"][download_option]
                extension = options["fmt"]

                try:
                    contents = structure.to(**options)
                except Exception as exc:
                    # don't fail silently, tell user what went wrong: the error
                    # text becomes the content of the downloaded file
                    contents = str(exc)

                encoded = b64encode(contents.encode("utf-8")).decode("ascii")

                download_data = {
                    "content": encoded,
                    "base64": True,
                    "type": "text/plain",
                    "filename": f"{file_prefix}.{extension}",
                }

            else:
                if "Relax" in download_option:
                    vis = MPRelaxSet(structure)
                    expected_filename = "MPRelaxSet.zip"
                else:
                    raise ValueError("No other VASP input sets currently supported.")

                with TemporaryDirectory() as tmpdir:
                    vis.write_input(tmpdir, potcar_spec=True, zip_output=True)
                    path = Path(tmpdir) / expected_filename
                    # must be read before the temporary directory is removed
                    encoded = b64encode(path.read_bytes()).decode("ascii")

                download_data = {
                    "content": encoded,
                    "base64": True,
                    "type": "application/zip",
                    "filename": f"{file_prefix} {expected_filename}",
                }

            return download_data

        @app.callback(
            Output(self.id("title_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_title(legend):
            if not legend:
                raise PreventUpdate
            legend = self.from_data(legend)
            return self._make_title(legend)

        @app.callback(
            Output(self.id("legend_container"), "children"),
            [Input(self.id("legend_data"), "data")],
        )
        @cache.memoize()
        def update_legend(legend):
            if not legend:
                raise PreventUpdate
            legend = self.from_data(legend)
            return self._make_legend(legend)

        @app.callback(
            [
                Output(self.id("bonding_algorithm_custom_cutoffs"), "data"),
                Output(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
            [Input(self.id("bonding_algorithm"), "value")],
            [
                State(self.id("graph"), "data"),
                State(self.id("bonding_algorithm_custom_cutoffs_container"), "style"),
            ],
        )
        @cache.memoize()
        def update_custom_bond_options(bonding_algorithm, graph, current_style):
            if not graph:
                raise PreventUpdate

            if bonding_algorithm == "CutOffDictNN":
                style = {}
            else:
                style = {"display": "none"}
                if style == current_style:
                    # no need to update rows if we're not showing them
                    raise PreventUpdate

            graph = self.from_data(graph)
            rows = self._make_bonding_algorithm_custom_cuffoff_data(graph)

            return rows, style

    def _make_legend(self, legend):
        """
        Build the legend Div: one colored, rounded label per legend entry.

        :param legend: legend data with "colors" (color hex code -> label) and
            "composition" entries; a falsy value gives an empty legend Div
        :return: html.Div with id ``self.id("legend")``
        """
        if not legend:
            return html.Div(id=self.id("legend"))

        def get_font_color(hex_code):
            # ensures contrasting font color for background color
            c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
            if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5:
                font_color = "#000000"
            else:
                font_color = "#ffffff"
            return font_color

        try:
            formula = Composition.from_dict(legend["composition"]).reduced_formula
        except Exception:
            # TODO: fix legend for Dummy Specie compositions
            formula = "Unknown"

        # order entries by where their label first appears in the formula
        # (labels not found in the formula get -1 and therefore sort first)
        legend_colors = OrderedDict(
            sorted(legend["colors"].items(), key=lambda x: formula.find(x[1]))
        )

        legend_elements = [
            html.Span(
                html.Span(
                    name, className="icon", style={"color": get_font_color(color)}
                ),
                className="button is-static is-rounded",
                style={"backgroundColor": color},
            )
            for color, name in legend_colors.items()
        ]

        return html.Div(
            legend_elements,
            id=self.id("legend"),
            style={"display": "flex"},
            className="buttons",
        )

    def _make_title(self, legend):
        """
        Build the title heading showing the reduced formula with subscripts.

        :param legend: legend data; its "composition" entry is used
        :return: H2 with id ``self.id("title")``; shows ``default_title`` when
            no composition is available
        """
        if not legend or (not legend.get("composition", None)):
            return H2(self.default_title, id=self.id("title"))

        composition = legend["composition"]
        if isinstance(composition, dict):
            try:
                composition = Composition.from_dict(composition)

                # strip DummySpecie if present (TODO: should be method in pymatgen)
                composition = Composition(
                    {
                        el: amt
                        for el, amt in composition.items()
                        if not isinstance(el, DummySpecie)
                    }
                )
                composition = composition.get_reduced_composition_and_factor()[0]
                formula = composition.reduced_formula
                # split into alternating non-digit / digit runs so that the
                # numbers can be rendered as subscripts
                formula_parts = re.findall(r"[^\d_]+|\d+", formula)

                formula_components = [
                    html.Sub(part.strip())
                    if part.isnumeric()
                    else html.Span(part.strip())
                    for part in formula_parts
                ]
            except Exception:
                formula_components = list(map(str, composition))
        else:
            # composition arrived in some other form; show its string
            # representation instead of failing on an unbound variable
            formula_components = [str(composition)]

        return H2(
            formula_components, id=self.id("title"), style={"display": "inline-block"}
        )

    @staticmethod
    def _make_bonding_algorithm_custom_cuffoff_data(graph):
        """
        Build the rows of the custom bond cut-off table.

        :param graph: a graph, structure or molecule (anything accepted by
            ``_get_struct_or_mol``); a falsy value gives a single empty row
        :return: list of row dicts with keys "A", "B" and "A—B", one per
            unordered pair of species (pairs of a species with itself included)
        """
        if not graph:
            return [{"A": None, "B": None, "A—B": None}]

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

        # can't use type_of_specie because it doesn't work with disordered structures
        species = {
            str(specie)
            for site_composition in struct_or_mol.species_and_occu
            for specie in site_composition
        }

        # sorted so that the row order is deterministic between calls
        return [
            {"A": first, "B": second, "A—B": 0}
            for first, second in combinations_with_replacement(sorted(species), 2)
        ]

    @property
    def _sub_layouts(self):
        """
        Build the individual layouts that make up this component.

        :return: dict with keys "struct", "options", "title" and "legend";
            "options" / "legend" are None when ``show_settings`` /
            ``show_legend`` are false
        """
        title_layout = html.Div(
            self._make_title(self._initial_data["legend_data"]),
            id=self.id("title_container"),
        )

        # dropdown label -> name of the NearNeighbors subclass
        nn_mapping = {
            "CrystalNN": "CrystalNN",
            "Custom Bonds": "CutOffDictNN",
            "Jmol Bonding": "JmolNN",
            "Minimum Distance (10% tolerance)": "MinimumDistanceNN",
            "O'Keeffe's Algorithm": "MinimumOKeeffeNN",
            "Hoppe's ECoN Algorithm": "EconNN",
            "Brunner's Reciprocal Algorithm": "BrunnerNN_reciprocal",
        }

        graph_generation_options = self.initial_data["graph_generation_options"]
        display_options = self.initial_data["display_options"]

        bonding_algorithm = dcc.Dropdown(
            options=[{"label": k, "value": v} for k, v in nn_mapping.items()],
            value=graph_generation_options["bonding_strategy"],
            clearable=False,
            id=self.id("bonding_algorithm"),
            persistence=SETTINGS.PERSISTENCE,
            persistence_type=SETTINGS.PERSISTENCE_TYPE,
        )

        # hidden by default; shown by a callback when "CutOffDictNN" is chosen
        bonding_algorithm_custom_cutoffs = html.Div(
            [
                html.Br(),
                dt.DataTable(
                    columns=[
                        {"name": "A", "id": "A"},
                        {"name": "B", "id": "B"},
                        {"name": "A—B /Å", "id": "A—B"},
                    ],
                    editable=True,
                    data=self._make_bonding_algorithm_custom_cuffoff_data(
                        self.initial_data.get("default")
                    ),
                    id=self.id("bonding_algorithm_custom_cutoffs"),
                ),
                html.Br(),
            ],
            id=self.id("bonding_algorithm_custom_cutoffs_container"),
            style={"display": "none"},
        )

        if self.show_settings:
            options_layout = Field(
                [
                    # TODO: hide if molecule
                    html.Label("Change unit cell:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {"label": "Input cell", "value": "input"},
                                {"label": "Primitive cell", "value": "primitive"},
                                {"label": "Conventional cell", "value": "conventional"},
                                {
                                    "label": "Reduced cell (Niggli)",
                                    "value": "reduced_niggli",
                                },
                                {"label": "Reduced cell (LLL)", "value": "reduced_lll"},
                            ],
                            # start from the choice given to __init__; this
                            # control feeds the graph_generation_options store,
                            # so a fixed "input" here would override that choice
                            value=graph_generation_options["unit_cell_choice"],
                            clearable=False,
                            id=self.id("unit-cell-choice"),
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                        ),
                        className="mpc-control",
                    ),
                    html.Div(
                        [
                            html.Label(
                                "Change bonding algorithm: ", className="mpc-label"
                            ),
                            bonding_algorithm,
                            bonding_algorithm_custom_cutoffs,
                        ]
                    ),
                    html.Label("Change color scheme:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {"label": "VESTA", "value": "VESTA"},
                                {"label": "Jmol", "value": "Jmol"},
                                {"label": "Accessible", "value": "accessible"},
                            ],
                            value=display_options["color_scheme"],
                            clearable=False,
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            id=self.id("color-scheme"),
                        ),
                        className="mpc-control",
                    ),
                    html.Label("Change atomic radii:", className="mpc-label"),
                    html.Div(
                        dcc.Dropdown(
                            options=[
                                {
                                    "label": "Ionic",
                                    "value": "specified_or_average_ionic",
                                },
                                {"label": "Covalent", "value": "covalent"},
                                {"label": "Van der Waals", "value": "van_der_waals"},
                                {
                                    "label": f"Uniform ({Legend.uniform_radius}Å)",
                                    "value": "uniform",
                                },
                            ],
                            value=display_options["radius_strategy"],
                            clearable=False,
                            persistence=SETTINGS.PERSISTENCE,
                            persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            id=self.id("radius_strategy"),
                        ),
                        className="mpc-control",
                    ),
                    html.Label("Draw options:", className="mpc-label"),
                    html.Div(
                        [
                            dcc.Checklist(
                                options=[
                                    {
                                        "label": "Draw repeats of atoms on periodic boundaries",
                                        "value": "draw_image_atoms",
                                    },
                                    {
                                        "label": "Draw atoms outside unit cell bonded to "
                                        "atoms within unit cell",
                                        "value": "bonded_sites_outside_unit_cell",
                                    },
                                    {
                                        "label": "Hide bonds where destination atoms are not shown",
                                        "value": "hide_incomplete_bonds",
                                    },
                                ],
                                # tick the boxes whose display option is enabled
                                value=[
                                    opt
                                    for opt in (
                                        "draw_image_atoms",
                                        "bonded_sites_outside_unit_cell",
                                        "hide_incomplete_bonds",
                                    )
                                    if display_options[opt]
                                ],
                                labelStyle={"display": "block"},
                                inputClassName="mpc-radio",
                                id=self.id("draw_options"),
                                persistence=SETTINGS.PERSISTENCE,
                                persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            )
                        ]
                    ),
                    html.Label("Hide/show:", className="mpc-label"),
                    html.Div(
                        [
                            dcc.Checklist(
                                options=[
                                    {"label": "Atoms", "value": "atoms"},
                                    {"label": "Bonds", "value": "bonds"},
                                    {"label": "Unit cell", "value": "unit_cell"},
                                    {"label": "Polyhedra", "value": "polyhedra"},
                                    {"label": "Axes", "value": "axes"},
                                ],
                                value=["atoms", "bonds", "unit_cell", "polyhedra"],
                                labelStyle={"display": "block"},
                                inputClassName="mpc-radio",
                                id=self.id("hide-show"),
                                persistence=SETTINGS.PERSISTENCE,
                                persistence_type=SETTINGS.PERSISTENCE_TYPE,
                            )
                        ],
                        className="mpc-control",
                    ),
                ]
            )
        else:
            options_layout = None

        if self.show_legend:
            legend_layout = html.Div(
                self._make_legend(self._initial_data["legend_data"]),
                id=self.id("legend_container"),
            )
        else:
            legend_layout = None

        struct_layout = html.Div(
            [
                CrystalToolkitScene(
                    [options_layout, legend_layout],
                    id=self.id("scene"),
                    className=self.className,
                    data=self.initial_data["scene"],
                    settings=self.initial_scene_settings,
                    sceneSize="100%",
                    fileOptions=list(self.download_options["Structure"]),
                    showControls=self.show_controls,
                    showExpandButton=self.show_expand_button,
                    showImageButton=self.show_image_button,
                    showExportButton=self.show_export_button,
                    showPositionButton=self.show_position_button,
                    **self.scene_kwargs,
                ),
                dcc.Download(id=self.id("download-image")),
                dcc.Download(id=self.id("download-structure")),
            ]
        )

        return {
            "struct": struct_layout,
            "options": options_layout,
            "title": title_layout,
            "legend": legend_layout,
        }

    def layout(self, size: str = "500px") -> html.Div:
        """
        :param size: a CSS dimension specifying width/height of Div
        :return: A html.Div containing the 3D structure or molecule
        """
        return html.Div(
            self._sub_layouts["struct"], style={"width": size, "height": size}
        )

    @staticmethod
    def _preprocess_structure(
        struct_or_mol: Structure | StructureGraph | Molecule | MoleculeGraph,
        unit_cell_choice: Literal[
            "input", "primitive", "conventional", "reduced_niggli", "reduced_lll"
        ] = "input",
    ):
        """
        Convert a Structure to the requested unit cell.

        :param struct_or_mol: input object; anything that is not a Structure
            is returned unchanged
        :param unit_cell_choice: which cell to convert to; "input" (or any
            unrecognized value) leaves the structure as given
        :return: the converted (or original) object
        """
        if not isinstance(struct_or_mol, Structure):
            return struct_or_mol

        if unit_cell_choice == "primitive":
            struct_or_mol = struct_or_mol.get_primitive_structure()
        elif unit_cell_choice == "conventional":
            sga = SpacegroupAnalyzer(struct_or_mol)
            struct_or_mol = sga.get_conventional_standard_structure()
        elif unit_cell_choice == "reduced_niggli":
            struct_or_mol = struct_or_mol.get_reduced_structure(
                reduction_algo="niggli"
            )
        elif unit_cell_choice == "reduced_lll":
            struct_or_mol = struct_or_mol.get_reduced_structure(reduction_algo="LLL")

        return struct_or_mol

    @staticmethod
    def _preprocess_input_to_graph(
        input: Structure | StructureGraph | Molecule | MoleculeGraph,
        bonding_strategy: str = DEFAULTS["bonding_strategy"],
        bonding_strategy_kwargs: dict | None = None,
    ) -> StructureGraph | MoleculeGraph:
        """
        Turn the input into a StructureGraph or MoleculeGraph.

        A graph that is passed in is returned as-is; otherwise bonds are
        computed with the requested NearNeighbors strategy. If computing the
        bonds raises, a graph without any bonds is returned instead.

        :param input: structure, molecule or an already-built graph
        :param bonding_strategy: name of a NearNeighbors subclass, see
            ``available_bonding_strategies``
        :param bonding_strategy_kwargs: keyword arguments for the strategy;
            for "CutOffDictNN" the "cut_off_dict" entry may be a list of
            ``[A, B, cutoff]`` triples. The passed dict is not modified.
        :return: the graph
        :raises ValueError: if ``bonding_strategy`` is not a known strategy name
        """
        if isinstance(input, Structure):
            # ensure fractional coordinates are normalized to be in [0,1)
            # (this is actually not guaranteed by Structure)
            try:
                input = input.as_dict(verbosity=0)
            except TypeError:
                # TODO: remove this, necessary for Slab(?), some structure subclasses don't have verbosity
                input = input.as_dict()
            for site in input["sites"]:
                site["abc"] = np.mod(site["abc"], 1)
            input = Structure.from_dict(input)

            if not input.is_ordered:
                # calculating bonds in disordered structures is currently very flaky
                bonding_strategy = "CutOffDictNN"

        # we assume most uses of this class will give a structure as an input argument,
        # meaning we have to calculate the graph for bonding information, however if
        # the graph is already known and supplied, we will use that
        if isinstance(input, (StructureGraph, MoleculeGraph)):
            return input

        strategies = StructureMoleculeComponent.available_bonding_strategies
        if bonding_strategy not in strategies:
            valid_subclasses = ", ".join(strategies)
            raise ValueError(
                "Bonding strategy not supported. Please supply a name of a NearNeighbor "
                f"subclass, choose from: {valid_subclasses}"
            )

        # work on a copy: the caller's dict may also live in a dcc.Store and
        # must stay JSON serializable (tuple keys are not)
        bonding_strategy_kwargs = dict(bonding_strategy_kwargs or {})
        if bonding_strategy == "CutOffDictNN":
            cut_off_dict = bonding_strategy_kwargs.get("cut_off_dict")
            if "cut_off_dict" in bonding_strategy_kwargs and not isinstance(
                cut_off_dict, dict
            ):
                # TODO: remove this hack by making args properly JSON serializable
                bonding_strategy_kwargs["cut_off_dict"] = {
                    (x[0], x[1]): x[2] for x in cut_off_dict
                }
        bonding_strategy = strategies[bonding_strategy](**bonding_strategy_kwargs)

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if isinstance(input, Structure):
                    graph = StructureGraph.with_local_env_strategy(
                        input, bonding_strategy
                    )
                else:
                    graph = MoleculeGraph.with_local_env_strategy(
                        input, bonding_strategy, reorder=False
                    )
        except Exception:
            # for some reason computing bonds failed, so let's not have any bonds(!)
            if isinstance(input, Structure):
                graph = StructureGraph.with_empty_graph(input)
            else:
                graph = MoleculeGraph.with_empty_graph(input)

        return graph

    @staticmethod
    def _get_struct_or_mol(
        graph: StructureGraph | MoleculeGraph | Structure | Molecule,
    ) -> Structure | Molecule:
        """
        Return the Structure or Molecule behind a graph.

        :param graph: a graph, or a Structure / Molecule (returned unchanged)
        :return: the underlying Structure or Molecule
        :raises ValueError: if ``graph`` is none of the supported types
        """
        if isinstance(graph, StructureGraph):
            return graph.structure
        if isinstance(graph, MoleculeGraph):
            return graph.molecule
        if isinstance(graph, (Structure, Molecule)):
            return graph
        raise ValueError(
            "Expected a StructureGraph, MoleculeGraph, Structure or Molecule, "
            f"got {type(graph).__name__}."
        )

    @staticmethod
    def get_scene_and_legend(
        graph: StructureGraph | MoleculeGraph | None,
        color_scheme=DEFAULTS["color_scheme"],
        color_scale=None,
        radius_strategy=DEFAULTS["radius_strategy"],
        draw_image_atoms=DEFAULTS["draw_image_atoms"],
        bonded_sites_outside_unit_cell=DEFAULTS["bonded_sites_outside_unit_cell"],
        hide_incomplete_bonds=DEFAULTS["hide_incomplete_bonds"],
        explicitly_calculate_polyhedra_hull=False,
        scene_additions=None,
        show_compass=DEFAULTS["show_compass"],
        group_by_site_property=None,
    ) -> tuple[Scene, dict[str, str]]:
        """
        Build the 3D scene and the legend data for a graph.

        :param graph: graph to draw; for None an empty Scene object and an
            empty dict are returned
        :param color_scheme: color scheme, see Legend class
        :param color_scale: passed to Legend as ``cmap_range``
        :param radius_strategy: passed to Legend as ``radius_scheme``
        :param draw_image_atoms: StructureGraph only, passed to ``get_scene``
        :param bonded_sites_outside_unit_cell: StructureGraph only, passed to ``get_scene``
        :param hide_incomplete_bonds: StructureGraph only, passed to
            ``get_scene`` as ``hide_incomplete_edges``
        :param explicitly_calculate_polyhedra_hull: StructureGraph only, passed to ``get_scene``
        :param scene_additions: appended to the "contents" of the scene JSON if truthy
        :param show_compass: visibility of the lattice axes added to the scene
        :param group_by_site_property: StructureGraph only, passed to ``get_scene``
        :return: ``(scene, legend)``; for a non-None graph the scene is the
            output of ``Scene.to_json()`` and the legend the output of
            ``Legend.get_legend()``
        """
        scene = Scene(name="StructureMoleculeComponentScene")

        if graph is None:
            return scene, {}

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)

        # TODO: add radius_scale
        legend = Legend(
            struct_or_mol,
            color_scheme=color_scheme,
            radius_scheme=radius_strategy,
            cmap_range=color_scale,
        )

        if isinstance(graph, StructureGraph):
            scene = graph.get_scene(
                draw_image_atoms=draw_image_atoms,
                bonded_sites_outside_unit_cell=bonded_sites_outside_unit_cell,
                hide_incomplete_edges=hide_incomplete_bonds,
                explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
                group_by_site_property=group_by_site_property,
                legend=legend,
            )
        elif isinstance(graph, MoleculeGraph):
            scene = graph.get_scene(legend=legend)

        scene.name = "StructureMoleculeComponentScene"

        # only objects with a lattice get the axes ("compass")
        if hasattr(struct_or_mol, "lattice"):
            axes = struct_or_mol.lattice._axes_from_lattice()
            axes.visible = show_compass
            scene.contents.append(axes)

        scene_json = scene.to_json()

        if scene_additions:
            # TODO: this might be cleaner if we had a Scene.from_json() method
            scene_json["contents"].append(scene_additions)

        return scene_json, legend.get_legend()

    def title_layout(self):
        """
        :return: A layout including the composition of the structure/molecule as a title.
        """
        return self._sub_layouts["title"]