# Source code for crystal_toolkit.core.mpcomponent

from __future__ import annotations

import logging
from abc import ABC
from ast import literal_eval
from base64 import b64encode
from collections import defaultdict
from itertools import chain
from json import JSONDecodeError, dumps, loads
from typing import Any, Literal

import dash
import dash_mp_components as mpc
import numpy as np
from dash import dash_table as dt
from dash import dcc, html
from dash.dependencies import ALL
from flask_caching import Cache
from monty.json import MontyDecoder, MSONable

from crystal_toolkit import __version__ as ct_version
from crystal_toolkit.helpers.layouts import Button, Icon, Loading, add_label_help
from crystal_toolkit.settings import SETTINGS

# Fallback cache, configured with flask_caching's "null" cache type, for use
# when no real cache backend (e.g. Redis) has been set up.
null_cache = Cache(config=dict(CACHE_TYPE="null"))

# Crystal Toolkit namespace: MPComponent.__init__ prepends this prefix to every
# component id (unless already present), so layouts added by Crystal Toolkit
# can be recognized.
CT_NAMESPACE = "CT"


class MPComponent(ABC):
    """The abstract base class for an MPComponent.

    MPComponent is designed to help render an MSONable object.
    """

    # reference to global Dash app
    app = None

    # reference to Flask cache
    cache = None

    # used to track all dcc.Stores required for all MPComponents to work,
    # keyed by the MPComponent id
    _app_stores_dict: dict[str, list[dcc.Store]] = defaultdict(list)

    # used to track what individual Dash components are defined
    # by this MPComponent
    _all_id_basenames: set[str] = set()

    # used to defer generation of callbacks until app.layout is defined
    # (see crystal_toolkit_layout), so callbacks can be validated against it
    _callbacks_to_generate: set[MPComponent] = set()

    @staticmethod
    def register_app(app: dash.Dash):
        """Register the Dash app with all MPComponents.

        This method must be called at least once in your Crystal Toolkit Dash
        app if you want to enable interactivity with the MPComponents.

        The "app" variable is a special global variable used by Dash/Flask,
        and registering it with MPComponent allows callbacks to be registered
        with the app on instantiation.

        Args:
            app: a Dash app instance
        """
        MPComponent.app = app

        # add metadata
        app.config.meta_tags.append(
            {
                "name": "generator",
                "content": f"Crystal Toolkit {ct_version} (Materials Project)",
            }
        )

        # set default title, but respect the user if they override it
        if app.title == "Dash":
            app.title = "Crystal Toolkit"

    @staticmethod
    def register_cache(cache: Cache):
        """Register a cache to be passed to every component's callbacks.

        This method must be called at least once in your Crystal Toolkit Dash
        app if you want to enable callback caching. Callback caching is one of
        the easiest ways to see significant performance improvements, especially
        for callbacks that are computationally expensive.

        Args:
            cache: a flask_caching Cache instance; if falsy, a "simple" Cache
                bound to the registered app's server is created instead (so
                register_app() must have been called first)
        """
        if cache:
            MPComponent.cache = cache
        else:
            MPComponent.cache = Cache(
                MPComponent.app.server, config={"CACHE_TYPE": "simple"}
            )

    @staticmethod
    def crystal_toolkit_layout(layout: html.Div) -> html.Div:
        """Add all registered component stores to ``layout`` and generate callbacks.

        Args:
            layout: the root layout of the app; it is modified in place.

        Returns:
            The same ``layout`` object with the dcc.Stores appended to its children.

        Raises:
            ValueError: if no Dash app has been registered via register_app().
        """
        if not MPComponent.app:
            raise ValueError(
                "Please register the Dash app with Crystal Toolkit "
                "using register_app()."
            )

        # Stores for *all* registered components are added, not only those
        # present in the initial layout: restricting to the initial layout would
        # break components that are displayed dynamically later.
        # Sorted so the generated layout is deterministic between runs.
        stores_to_add = []
        for basename in sorted(MPComponent._all_id_basenames):
            stores_to_add += MPComponent._app_stores_dict[basename]

        # children may be None, a single component, a tuple or a list
        children = layout.children
        if children is None:
            children = []
        elif not isinstance(children, (list, tuple)):
            children = [children]
        layout.children = [*children, *stores_to_add]

        # set app.layout to layout so that callbacks can be validated
        MPComponent.app.layout = layout

        for component in MPComponent._callbacks_to_generate:
            component.generate_callbacks(MPComponent.app, MPComponent.cache)

        return layout

    @staticmethod
    def register_crystal_toolkit(app, layout, cache=None):
        """Register ``app`` and ``cache`` and set ``app.layout`` in one call.

        Enables ``suppress_callback_exceptions`` on the app because not every
        component layout is guaranteed to be displayed at all times.

        Args:
            app: a Dash app instance
            layout: the root layout for the app
            cache: optional flask_caching Cache (see register_cache)
        """
        MPComponent.register_app(app)
        MPComponent.register_cache(cache)
        app.config["suppress_callback_exceptions"] = True
        app.layout = MPComponent.crystal_toolkit_layout(layout)

    @staticmethod
    def all_app_stores() -> html.Div:
        """Return a hidden element containing the dcc.Store of every MPComponent.

        This must be included somewhere in your Crystal Toolkit Dash app's
        layout for interactivity to work. This is a hidden element that contains
        the MSON for each MPComponent.

        Returns:
            a html.Div Dash Layout
        """
        return html.Div(
            list(chain.from_iterable(MPComponent._app_stores_dict.values()))
        )

    def __init__(
        self,
        default_data: MSONable | dict | str | None = None,
        id: str | None = None,
        links: dict[str, str] | None = None,
        storage_type: Literal["memory", "local", "session"] = "memory",
        disable_callbacks: bool = False,
    ):
        """The abstract base class for an MPComponent.

        The MPComponent is designed to help render any MSONable object,
        for example many of the objects in pymatgen (Structure, PhaseDiagram, etc.)

        To instantiate an MPComponent, you will need to create it outside of
        your Dash app layout:

        my_component = MPComponent(my_msonable_object)

        Then, inside the app.layout, you can include the component's layout
        anywhere you choose: my_component.layout

        If you want the layouts to be interactive, i.e. to respond to callbacks,
        you have to also use the MPComponent.register_app(app) method in your app,
        and also include MPComponent.all_app_stores in your app.layout (an
        invisible layout that contains the MSON itself).

        If you do not want the layouts to be interactive, set disable_callbacks
        to True to prevent errors.

        If including multiple MPComponents of the same type, make sure to set
        the id field to a unique value, as you would in any other Dash component.

        When sub-classing MPComponent, the most important methods to implement
        are _sub_layouts and generate_callbacks().

        Args:
            default_data: initial contents for the component, can be None
            id: a unique id, required if multiple of the same type of
                MPComponent are included in an app; prefixed with CT_NAMESPACE
                if it does not already start with it
            links: if set, will set store contents from the stores of another
                component to reduce unnecessary callbacks and duplication of data,
                note that links are one directional only and specify the origin
                stores, e.g. set {"default": my_other_component.id()} to fill this
                component's default store contents from the other component's
                default store, or {"graph": my_other_component.id("graph")} to
                fill this component's "graph" store from another component's
                "graph" store. The dictionary passed in is not modified.
            storage_type: whether to persist contents of component through browser
                refresh or browser sessions, use with caution, defaults to "memory"
                so component store will be emptied on refresh, see dcc.Store
                documentation for more information
            disable_callbacks: if True, will not generate callbacks, useful for
                static layouts or returning new MPComponents dynamically where
                generating callbacks are not possible due to limitations of Dash
        """
        # ensure ids are unique
        # Note: shadowing Python built-in here, but only because Dash does it...
        if id is None:
            # TODO: this could lead to duplicate ids and an error, but if
            # setting random ids, this could also lead to undefined behavior
            id = f"{CT_NAMESPACE}{type(self).__name__}"
        elif not id.startswith(CT_NAMESPACE):
            id = f"{CT_NAMESPACE}{id}"
        MPComponent._all_id_basenames.add(id)

        self._id = id
        self._all_ids: set[str] = set()
        self._stores = {}
        self._initial_data = {}
        # copy so that registering the "default" link below does not mutate
        # the dictionary supplied by the caller
        self.links = dict(links) if links else {}

        self.create_store(
            name="default", initial_data=default_data, storage_type=storage_type
        )
        self.links["default"] = self.id()

        if not disable_callbacks:
            # callbacks generated as final step by crystal_toolkit_layout()
            self._callbacks_to_generate.add(self)

        self.logger = logging.getLogger(type(self).__name__)

    def id(
        self,
        name: str = "default",
        is_kwarg: bool = False,
        idx: Any = False,
        hint: Any = None,
        is_store: bool = False,
    ) -> str | dict[str, str]:
        """Generate an id for a named element of this component's layout.

        The id combines ``name`` with the base id of the MPComponent itself,
        useful for generating ids of individual components in the layout.

        In the special case of the id of an element that is used to re-construct
        a keyword argument for a specific class, it will store information
        necessary to reconstruct that keyword argument (e.g. its type hint and,
        in the case of a vector or matrix, the corresponding index).

        A hint could be a tuple for a numpy array of that shape, e.g. (3, 3) for
        a 3x3 matrix, (1, 3) for a vector, or "literal" to parse kwarg value
        using ast.literal_eval, or "bool" to parse a boolean value. In future
        iterations, we may be able to replace this with native Python type hints.
        The problem here is being able to specify array shape where appropriate.

        Args:
            name: e.g. "default"
            is_kwarg: if True, return a pattern-matching dict id describing a
                keyword-argument input (see reconstruct_kwargs_from_state)
            idx: index of this element within a vector/matrix kwarg (stored as str)
            hint: how to parse the kwarg value (stored as str), see above
            is_store: retained for backwards compatibility; has no effect

        Returns:
            The base id (e.g. "CTMPComponent") for "default", "<base id>_<name>"
            for any other name, the linked id if ``name`` is in ``self.links``,
            or a dict id when ``is_kwarg`` is True.
        """
        if is_kwarg:
            return {
                "component_id": self._id,
                "kwarg_label": name,
                "idx": str(idx),
                "hint": str(hint),
            }

        # if we're linking to another component, return that id
        if name in self.links:
            return self.links[name]

        # otherwise create a new id
        self._all_ids.add(name)
        if name == "default":
            return self._id
        return f"{self._id}_{name}"

    def create_store(
        self,
        name: str,
        initial_data: MSONable | dict | str | None = None,
        storage_type: Literal["memory", "local", "session"] = "memory",
        debug_clear: bool = False,
    ):
        """Generate a dcc.Store to hold something (MSONable object, Dict or string).

        The store is registered so that it will be included in the Dash app
        automatically. The initial data will be stored in
        self._initial_data[name].

        Args:
            name: name for the store
            initial_data: initial data to include
            storage_type: as in dcc.Store
            debug_clear: set to True to empty the store if using persistent storage
        """
        # if we're linking to another component, do not create a new store
        if name in self.links:
            return

        store = dcc.Store(
            id=self.id(name, is_store=True),
            data=initial_data,
            storage_type=storage_type,
            clear_data=debug_clear,
        )

        self._stores[name] = store
        self._initial_data[name] = initial_data
        MPComponent._app_stores_dict[self.id()].append(store)

    @property
    def initial_data(self):
        """:return: Initial data for all the stores defined by component,
        keyed by store name.
        """
        return self._initial_data

    @staticmethod
    def from_data(data):
        """Convert the contents of a dcc.Store back into a Python object.

        :param data: JSON-serializable contents of a dcc.Store (e.g. an MSON dict)
        :return: a Python object, decoded with monty's MontyDecoder
        """
        return loads(dumps(data), cls=MontyDecoder)

    @property
    def all_stores(self) -> list[str]:
        """:return: List of all store names generated by this component"""
        return list(self._stores)

    @property
    def all_ids(self) -> list[str]:
        """:return: List of all (non-store) ids generated by this component"""
        # build the store lookup once rather than once per id
        store_names = set(self._stores)
        return [
            component_id
            for component_id in self._all_ids
            if component_id not in store_names
        ]

    def __repr__(self):
        return f"{self.id()}<{type(self).__name__}>"

    def __str__(self):
        ids = "\n".join(
            [f"* {component_id} " for component_id in sorted(self.all_ids)]
        )

        stores = "\n".join([f"* {store} " for store in sorted(self.all_stores)])

        layouts = "\n".join([f"* {layout} " for layout in sorted(self._sub_layouts)])

        return f"""{self.id()}<{type(self).__name__}> \n IDs: \n{ids} \n Stores: \n{stores} \n Sub-layouts: \n{layouts}"""

    @property
    def _sub_layouts(self):
        """Layouts associated with this component.

        Available for book-keeping if your component is complex, so that the
        layout() method just assembles individual sub-layouts.

        :return: A dictionary with names of layouts as keys (str) and Dash
        layouts (e.g. html.Div) as values.
        """
        return {}

    def layout(self) -> html.Div:
        """:return: A Dash layout for the full component. Basic implementation
        provided, but should in general be overridden.
        """
        return html.Div(list(self._sub_layouts.values()))

    def generate_callbacks(self, app, cache):
        """Generate all callbacks associated with the layouts in this app.

        Assume that "suppress_callback_exceptions" is True, since it is not
        always guaranteed that all layouts will be displayed to the end user at
        all times, but it's important the callbacks are defined on the server.
        """
        return None

    @staticmethod
    def _resolve_default(default, state, kwarg_label):
        """Return ``default`` if given, otherwise ``state[kwarg_label]`` (or None).

        An explicit ``is None`` check is used so that falsy defaults such as
        0, False or "" are honoured instead of being replaced by the value in
        ``state``.
        """
        if default is not None:
            return default
        return (state or {}).get(kwarg_label)

    def get_numerical_input(
        self,
        kwarg_label: str,
        default: int | float | list | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str | None = None,
        is_int: bool = False,
        shape: tuple[int, ...] = (),
        **kwargs,
    ):
        """Generate a Dash input layout for a numerical (scalar/vector/matrix) kwarg.

        :param kwarg_label: The name of the corresponding Python input, this is
        used to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with
        the kwarg_label as a key and the default value as a value. Ignored if
        `default` is set. It can be useful to use `state` if you want to set
        defaults for multiple inputs from a single dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :param is_int: if True, values are cast to int and the input uses step=1
        :param shape: (3, 3) for matrix, (1, 3) for vector, (1, 1) for scalar
        :param kwargs: passed to each dcc.Input; ``style`` is merged into the
        default style
        :return: a Dash layout
        """
        default = self._resolve_default(default, state, kwarg_label)
        # broadcast the default to the requested shape (one value per input box)
        default = np.reshape(np.full(shape, default), shape)

        style = {
            "textAlign": "center",
            # shorter default width if matrix or vector
            "width": "5rem",
            "marginRight": "0.2rem",
            "marginBottom": "0.2rem",
            "height": "36px",
        }
        # caller-supplied style overrides the defaults above
        style.update(kwargs.pop("style", {}))

        cast = int if is_int else float
        input_kwargs = {"step": 1} if is_int else {}
        input_kwargs.update(kwargs)

        def matrix_element(idx, value=0):
            # TODO: maybe move element out of the name
            mid = self.id(kwarg_label, is_kwarg=True, idx=idx, hint=shape)
            if isinstance(value, np.ndarray):
                value = value.item()
            return dcc.Input(
                id=mid,
                inputMode="numeric",
                debounce=True,
                className="input",
                style=style,
                value=cast(value) if value is not None else None,
                persistence=True,
                type="number",
                **input_kwargs,
            )

        # Group input boxes into visual rows. Note that shape = () for floats,
        # shape = (3,) for vectors, but we may also need to accept e.g. (3, 1).
        # Each visual row collects elements sharing the *second* index and is
        # ordered by the first index, so a 2D array is displayed transposed
        # (kept for compatibility; presumably intentional — confirm).
        grouped = defaultdict(dict)
        it = np.nditer(default, flags=["multi_index", "refs_ok"])
        while not it.finished:
            idx = it.multi_index
            display_row = idx[1] if len(idx) > 1 else 0
            display_col = idx[0] if len(idx) > 0 else 0
            grouped[display_row][display_col] = matrix_element(idx, value=it[0])
            it.iternext()

        matrix_div_contents = []
        for row_key in sorted(grouped):
            cells = grouped[row_key]
            matrix_div_contents.append(html.Div([cells[col] for col in sorted(cells)]))

        matrix = html.Div(matrix_div_contents)

        return add_label_help(matrix, label, help_str)

    def get_slider_input(
        self,
        kwarg_label: str,
        default: Any | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str | None = None,
        multiple: bool = False,
        **kwargs,
    ):
        """Generate a slider input layout for a numerical kwarg (or range).

        :param kwarg_label: The name of the corresponding Python input, this is
        used to name the component.
        :param default: A default value for this input: a number, or a list of
        numbers (e.g. [low, high]) if `multiple` is True.
        :param state: Used to set default state for this input, use a dict with
        the kwarg_label as a key and the default value as a value. Ignored if
        `default` is set.
        :param label: A description for this input.
        :param help_str: Text for a tooltip when hovering over label.
        :param multiple: if True, use mpc.DualRangeSlider to select a range,
        otherwise mpc.RangeSlider
        :param kwargs: passed to the slider; pass `domain` to set the slider
        bounds explicitly, otherwise [0, 2 * default] is used (with the largest
        value of `default` if it is a list)
        :return: a Dash layout
        """
        default = self._resolve_default(default, state, kwarg_label)

        slider_kwargs = {}
        if "domain" not in kwargs:
            # mpc.RangeSlider requires a domain to be specified
            if default is None:
                raise TypeError(
                    f"Cannot infer a slider domain for {kwarg_label!r} without a "
                    "default value; please pass `domain` explicitly."
                )
            upper = max(default) if isinstance(default, (list, tuple)) else default
            slider_kwargs["domain"] = [0, upper * 2]
        slider_kwargs.update(kwargs)

        slider_cls = mpc.DualRangeSlider if multiple else mpc.RangeSlider
        slider_input = slider_cls(
            id=self.id(kwarg_label, is_kwarg=True, hint="slider"),
            value=default,
            **slider_kwargs,
        )

        return add_label_help(slider_input, label, help_str)

    def get_bool_input(
        self,
        kwarg_label: str,
        default: bool | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str | None = None,
        **kwargs,
    ):
        """Generate a switch input layout for a boolean kwarg.

        :param kwarg_label: The name of the corresponding Python input, this is
        used to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with
        the kwarg_label as a key and the default value as a value. Ignored if
        `default` is set (including an explicit False). It can be useful to use
        `state` if you want to set defaults for multiple inputs from a single
        dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :return: a Dash layout
        """
        default = self._resolve_default(default, state, kwarg_label)

        bool_input = mpc.Switch(
            id=self.id(kwarg_label, is_kwarg=True, hint="bool"),
            value=bool(default),
            hasLabel=True,
            **kwargs,
        )

        return add_label_help(bool_input, label, help_str)

    def get_choice_input(
        self,
        kwarg_label: str,
        default: str | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str | None = None,
        options: list[dict] | None = None,
        clearable: bool = False,
        **kwargs,
    ):
        """Generate an mpc.Select input layout for a kwarg with pre-defined values.

        :param kwarg_label: The name of the corresponding Python input, this is
        used to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with
        the kwarg_label as a key and the default value as a value. Ignored if
        `default` is set. It can be useful to use `state` if you want to set
        defaults for multiple inputs from a single dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :param options: Options to choose from, as per dcc.Dropdown
        :param clearable: If True, will allow Dropdown to be cleared after a
        selection is made.
        :param kwargs: passed to mpc.Select as `arbitraryProps`
        :return: a Dash layout
        """
        default = self._resolve_default(default, state, kwarg_label)

        option_input = mpc.Select(
            id=self.id(kwarg_label, is_kwarg=True, hint="literal"),
            options=options or [],
            value=default,
            isClearable=clearable,
            arbitraryProps=dict(kwargs),
        )

        return add_label_help(option_input, label, help_str)

    def get_dict_input(
        self,
        kwarg_label: str,
        default: Any | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str | None = None,
        key_name: str = "key",
        value_name: str = "value",
    ):
        """Generate an editable two-column DataTable input for a dict kwarg.

        :param kwarg_label: The name of the corresponding Python input, this is
        used to name the component.
        :param default: A default dict for this input.
        :param state: Used to set default state for this input (keyed by
        kwarg_label). Ignored if `default` is set.
        :param label: A description for this input.
        :param help_str: Text for a tooltip when hovering over label.
        :param key_name: column header for the dict keys
        :param value_name: column header for the dict values
        :return: a Dash layout
        """
        default = self._resolve_default(default, state, kwarg_label) or {}

        dict_input = dt.DataTable(
            id=self.id(kwarg_label, is_kwarg=True, hint="dict"),
            columns=[
                {"id": "key", "name": key_name},
                {"id": "value", "name": value_name},
            ],
            data=[{"key": k, "value": v} for k, v in default.items()],
            editable=True,
            persistence=False,
        )

        return add_label_help(dict_input, label, help_str)

    def get_kwarg_id(self, kwarg_name) -> dict:
        """Return a pattern-matching id matching every input of one kwarg.

        :param kwarg_name: the kwarg_label used when creating the input(s)
        :return: a dict id with ALL wildcards for "idx" and "hint"
        """
        return {
            "component_id": self._id,
            "kwarg_label": kwarg_name,
            "idx": ALL,
            "hint": ALL,
        }

    def get_all_kwargs_id(self) -> dict:
        """Return a pattern-matching id matching every kwarg input of this component.

        :return: a dict id with ALL wildcards for "kwarg_label", "idx" and "hint"
        """
        return {"component_id": self._id, "kwarg_label": ALL, "idx": ALL, "hint": ALL}

    def reconstruct_kwarg_from_state(self, state, kwarg_name):
        """Reconstruct a single kwarg from ``state``.

        Raises KeyError if no input for ``kwarg_name`` is found in the state.
        """
        return self.reconstruct_kwargs_from_state(
            state=state, kwarg_labels=[kwarg_name]
        )[kwarg_name]

    def reconstruct_kwargs_from_state(self, state=None, kwarg_labels=None) -> dict:
        """Generate keyword arguments from the values of kwarg inputs.

        :param state: optional, a Dash callback context input or state; if
        empty, the inputs and states of the current dash.callback_context are used
        :param kwarg_labels: optional, parse only a specific kwarg (str) or list
        of kwargs
        :return: A dictionary of keyword arguments with their values
        """
        if not state:
            state = {}
            state.update(dash.callback_context.inputs)
            state.update(dash.callback_context.states)

        # a bare string is a single label; `in` on a str would match substrings
        if isinstance(kwarg_labels, str):
            kwarg_labels = [kwarg_labels]

        required_keys = ("kwarg_label", "hint", "idx")
        kwargs = {}

        for k, v in state.items():
            # TODO: hopefully this will be less hacky in future Dash versions
            # remove trailing ".value" and convert back into dictionary
            try:
                d = loads(k[: -len(".value")])
            except JSONDecodeError:
                continue

            # skip ids that were not created by self.id(..., is_kwarg=True)
            if not isinstance(d, dict) or not all(key in d for key in required_keys):
                continue

            kwarg_label = d["kwarg_label"]

            if kwarg_labels and kwarg_label not in kwarg_labels:
                continue

            try:
                k_type = literal_eval(d["hint"])
            except (ValueError, SyntaxError):
                # non-literal hints, e.g. "literal", "bool", "slider", "dict"
                k_type = d["hint"]
            idx = literal_eval(d["idx"])

            try:
                if isinstance(k_type, tuple):
                    # matrix or vector, assembled one element at a time
                    if kwarg_label not in kwargs:
                        kwargs[kwarg_label] = np.empty(k_type)
                    v = literal_eval(str(v))
                    if v is None or kwargs[kwarg_label] is None:
                        # require all elements to have value, otherwise set
                        # entire kwarg to None
                        kwargs[kwarg_label] = None
                    elif isinstance(v, list):
                        self.logger.warning(
                            "Unexpected list value for element %s of kwarg %r: %r",
                            idx,
                            kwarg_label,
                            v,
                        )
                        # a single element cannot hold a list: invalidate the kwarg
                        kwargs[kwarg_label] = None
                    else:
                        kwargs[kwarg_label][idx] = v
                elif k_type == "literal":
                    try:
                        kwargs[kwarg_label] = literal_eval(str(v))
                    except (ValueError, SyntaxError):
                        kwargs[kwarg_label] = str(v)
                elif k_type in ("bool", "slider"):
                    kwargs[kwarg_label] = v
                elif k_type == "dict":
                    # dict (DataTable) inputs are not currently reconstructed
                    pass
            except Exception:
                # Not raised intentionally but if you notice this in logs please investigate.
                self.logger.exception(
                    "This is a problem, debug required. id=%r value=%r type=%s",
                    d,
                    v,
                    type(v),
                )

        for k, v in kwargs.items():
            if isinstance(v, np.ndarray):
                kwargs[k] = v.tolist()

        if SETTINGS.DEBUG_MODE:
            print(type(self).__name__, "kwargs", kwargs)

        return kwargs

    @staticmethod
    def datauri_from_fig(
        fig, fmt: str = "png", width: int = 600, height: int = 400, scale: int = 4
    ) -> str:
        """Generate a data URI from a Plotly Figure.

        Requires kaleido to be installed (imported lazily).

        :param fig: Plotly Figure object or corresponding dictionary
        :param fmt: "png", "jpg", etc. (see PlotlyScope for supported formats)
        :param width: width in pixels
        :param height: height in pixels
        :param scale: scale factor
        :return: a "data:image/<fmt>;base64,..." URI string
        """
        from kaleido.scopes.plotly import PlotlyScope

        scope = PlotlyScope()
        output = scope.transform(
            fig, format=fmt, width=width, height=height, scale=scale
        )
        image = b64encode(output).decode("ascii")

        return f"data:image/{fmt};base64,{image}"

    def get_figure_placeholder(self, figure_id: str) -> html.Div:
        """Get a layout to act as a placeholder for an interactive figure.

        When used with `generate_static_figure_callbacks`, and assuming kaleido
        is installed on the server, a static image placeholder will be generated.
        NOTE(review): `generate_static_figure_callbacks` is not defined in this
        module — confirm where it lives.

        :return: a html.Div containing a Loading wrapper and a
        "Make Plot Interactive" button
        """
        return html.Div(
            [
                html.Div(
                    [Loading(id=self.id(f"{figure_id}-wrapped-figure-inner"))],
                    # NOTE(review): unlike the other ids this one ignores
                    # figure_id, so two placeholders on one component would
                    # share an id; left unchanged in case callbacks rely on it.
                    id=self.id("wrapped-figure-outer"),
                ),
                Button(
                    [Icon(kind="chart-pie"), html.Span(), "Make Plot Interactive"],
                    kind="primary",
                    id=self.id(f"{figure_id}-wrapped-figure-button"),
                ),
            ]
        )