# Source code for crystal_toolkit.core.mpcomponent
from __future__ import annotations
import logging
from abc import ABC
from ast import literal_eval
from base64 import b64encode
from collections import defaultdict
from itertools import chain
from json import JSONDecodeError, dumps, loads
from typing import Any, Literal
import dash
import dash_mp_components as mpc
import numpy as np
from dash import dash_table as dt
from dash import dcc, html
from dash.dependencies import ALL
from flask_caching import Cache
from monty.json import MontyDecoder, MSONable
from crystal_toolkit import __version__ as ct_version
from crystal_toolkit.helpers.layouts import Button, Icon, Loading, add_label_help
from crystal_toolkit.settings import SETTINGS
# Crystal Toolkit namespace: prepended to every component id so that layouts
# added by Crystal Toolkit can be recognized
CT_NAMESPACE = "CT"

# fallback cache (flask_caching "null" backend) for when Redis etc. is not set up
null_cache = Cache(config={"CACHE_TYPE": "null"})
class MPComponent(ABC):
    """
    Abstract base class for Crystal Toolkit components.

    An MPComponent renders an MSONable object (or a dict / str) as a Dash
    layout. Its data lives in dcc.Store components which are registered in
    class-level registries, so that they can be appended to the app layout
    automatically.
    """

    # the global Dash app, set by register_app()
    app = None
    # the global flask_caching Cache, set by register_cache()
    cache = None

    # registries shared by every MPComponent (intentionally class-level):
    # dcc.Stores each component needs in the layout, keyed by component id
    _app_stores_dict: dict[str, list[dcc.Store]] = defaultdict(list)
    # base ids of all components created so far
    _all_id_basenames: set[str] = set()
    # components whose generate_callbacks() is deferred until app.layout is
    # defined (see crystal_toolkit_layout), so callbacks can be validated
    _callbacks_to_generate: set[MPComponent] = set()
[docs]    @staticmethod
    def register_app(app: dash.Dash):
        """
        This method must be called at least once in your Crystal
        Toolkit Dash app if you want to enable interactivity with the
        MPComponents. The "app" variable is a special global
        variable used by Dash/Flask, and registering it with
        MPComponent allows callbacks to be registered with the
        app on instantiation.
        Args:
            app: a Dash app instance
        """
        MPComponent.app = app
        # add metadata
        app.config.meta_tags.append(
            {
                "name": "generator",
                "content": f"Crystal Toolkit {ct_version} (Materials Project)",
            }
        )
        # set default title, but respect the user if they override it
        if app.title == "Dash":
            app.title = "Crystal Toolkit"
[docs]    @staticmethod
    def register_cache(cache: Cache):
        """
        This method must be called at least once in your
        Crystal Toolkit Dash app if you want to enable
        callback caching. Callback caching is one of the
        easiest ways to see significant performance
        improvements, especially for callbacks that are
        computationally expensive.
        Args:
            cache: a flask_caching Cache instance
        """
        if cache:
            MPComponent.cache = cache
        else:
            MPComponent.cache = Cache(
                MPComponent.app.server, config={"CACHE_TYPE": "simple"}
            )
[docs]    @staticmethod
    def crystal_toolkit_layout(layout: html.Div) -> html.Div:
        if not MPComponent.app:
            raise ValueError(
                "Please register the Dash app with Crystal Toolkit "
                "using register_app()."
            )
        # layout_str = str(layout)
        stores_to_add = []
        for basename in MPComponent._all_id_basenames:
            # can use "if basename in layout_str:" to restrict to components present in initial layout
            # this would cause bugs for components displayed dynamically
            stores_to_add += MPComponent._app_stores_dict[basename]
        layout.children += stores_to_add
        # set app.layout to layout so that callbacks can be validated
        MPComponent.app.layout = layout
        for component in MPComponent._callbacks_to_generate:
            component.generate_callbacks(MPComponent.app, MPComponent.cache)
        return layout
[docs]    @staticmethod
    def register_crystal_toolkit(app, layout, cache=None):
        MPComponent.register_app(app)
        MPComponent.register_cache(cache)
        app.config["suppress_callback_exceptions"] = True
        app.layout = MPComponent.crystal_toolkit_layout(layout)
[docs]    @staticmethod
    def all_app_stores() -> html.Div:
        """
        This must be included somewhere in your
        Crystal Toolkit Dash app's layout for
        interactivity to work. This is a hidden element
        that contains the MSON for each MPComponent.
        Returns: a html.Div Dash Layout
        """
        return html.Div(
            list(chain.from_iterable(MPComponent._app_stores_dict.values()))
        )
    def __init__(
        self,
        default_data: MSONable | dict | str | None = None,
        id: str | None = None,
        links: dict[str, str] | None = None,
        storage_type: Literal["memory", "local", "session"] = "memory",
        disable_callbacks: bool = False,
    ):
        """
        The abstract base class for an MPComponent.

        The MPComponent is designed to help render any MSONable object,
        for example many of the objects in pymatgen (Structure, PhaseDiagram, etc.)

        To instantiate an MPComponent, you will need to create it outside
        of your Dash app layout:

        my_component = MPComponent(my_msonable_object)

        Then, inside the app.layout, you can include the component's layout
        anywhere you choose: my_component.layout

        If you want the layouts to be interactive, i.e. to respond to callbacks,
        you have to also use the MPComponent.register_app(app) method in your app,
        and also include MPComponent.all_app_stores in your app.layout (an
        invisible layout that contains the MSON itself).

        If you do not want the layouts to be interactive, set disable_callbacks
        to True to prevent errors.

        If including multiple MPComponents of the same type, make sure
        to set the id field to a unique value, as you would in any other
        Dash component.

        When sub-classing MPComponent, the most important methods to implement
        are _sub_layouts and generate_callbacks().

        Args:
            default_data: initial contents for the component, can be None
            id: a unique id, required if multiple of the same type of
                MPComponent are included in an app; the "CT" namespace prefix
                is added if missing
            links: if set, will set store contents from the stores of another
                component to reduce unnecessary callbacks and duplication of data,
                note that links are one directional only and specify the origin
                stores, e.g. set {"default": my_other_component.id()} to fill this
                component's default store contents from the other component's
                default store, or {"graph": my_other_component.id("graph")} to
                fill this component's "graph" store from another component's
                "graph" store; the dict is copied, never modified
            storage_type: whether to persist contents of component through
                browser refresh or browser sessions, use with caution, defaults
                to "memory" so component store will be emptied on refresh, see
                dcc.Store documentation for more information
            disable_callbacks: if True, will not generate callbacks, useful
                for static layouts or returning new MPComponents dynamically where
                generating callbacks are not possible due to limitations of Dash
        """
        # Note: `id` shadows the Python built-in, but only because Dash does it...
        # TODO: do something else here
        if id is None:
            # TODO: this could lead to duplicate ids and an error, but if
            # setting random ids, this could also lead to undefined behavior
            id = f"{CT_NAMESPACE}{type(self).__name__}"
        elif not id.startswith(CT_NAMESPACE):
            id = f"{CT_NAMESPACE}{id}"
        MPComponent._all_id_basenames.add(id)

        self._id = id
        self._all_ids: set[str] = set()
        self._stores = {}
        self._initial_data = {}
        # copy, so that registering our own "default" link below never
        # mutates a dict owned (and possibly reused) by the caller
        self.links = dict(links) if links else {}

        self.create_store(
            name="default", initial_data=default_data, storage_type=storage_type
        )
        self.links["default"] = self.id()

        if not disable_callbacks:
            # callbacks generated as final step by crystal_toolkit_layout()
            self._callbacks_to_generate.add(self)

        self.logger = logging.getLogger(type(self).__name__)
[docs]    def id(
        self,
        name: str = "default",
        is_kwarg: bool = False,
        idx=False,
        hint=None,
        is_store: bool = False,
    ) -> str | dict[str, str]:
        """
        Generate an id from a name combined with the
        base id of the MPComponent itself, useful for generating
        ids of individual components in the layout.
        In the special case of the id of an element that is used to re-construct
        a keyword argument for a specific class, it will store information necessary
        to reconstruct that keyword argument (e.g. its type hint and, in the case of
        a vector or matrix, the corresponding index).
        A hint could be a tuple for a numpy array of that shape, e.g. (3, 3) for a 3x3 matrix,
        (1, 3) for a vector, or "literal" to parse kwarg value using ast.literal_eval, or "bool"
        to parse a boolean value. In future iterations, we may be able to replace this with native
        Python type hints. The problem here is being able to specify array shape where appropriate.
        Args:
            name: e.g. "default"
        Returns: e.g. "MPComponent_default"
        """
        if name in self._stores:
            is_store = True
        if is_kwarg:
            return {
                "component_id": self._id,
                "kwarg_label": name,
                "idx": str(idx),
                "hint": str(hint),
            }
        # if we're linking to another component, return that id
        if name in self.links:
            return self.links[name]
        # otherwise create a new id
        self._all_ids.add(name)
        if name != "default":
            name = f"{self._id}_{name}"
        else:
            name = f"{self._id}"
        return name
        if is_store:
            return name
        else:
            return {"id": name}
[docs]    def create_store(
        self,
        name: str,
        initial_data: MSONable | dict | str | None = None,
        storage_type: Literal["memory", "local", "session"] = "memory",
        debug_clear: bool = False,
    ):
        """
        Generate a dcc.Store to hold something (MSONable object, Dict
        or string), and register it so that it will be included in the
        Dash app automatically.
        The initial data will be stored in a class attribute as
        self._initial_data[name].
        Args:
            name: name for the store
            initial_data: initial data to include
            storage_type: as in dcc.Store
            debug_clear: set to True to empty the store if using
            persistent storage
        """
        # if we're linking to another component, do not create a new store
        if name in self.links:
            return
        store = dcc.Store(
            id=self.id(name, is_store=True),
            data=initial_data,
            storage_type=storage_type,
            clear_data=debug_clear,
        )
        self._stores[name] = store
        self._initial_data[name] = initial_data
        MPComponent._app_stores_dict[self.id()].append(store)
    @property
    def initial_data(self):
        """
        :return: Initial data for all the stores defined by component,
        keyed by store name.
        """
        return self._initial_data
[docs]    @staticmethod
    def from_data(data):
        """
        Converts the contents of a dcc.Store back into a Python object.
        :param data: contents of a dcc.Store created by to_data
        :return: a Python object
        """
        return loads(dumps(data), cls=MontyDecoder)
    @property
    def all_stores(self) -> list[str]:
        """
        :return: List of all store ids generated by this component
        """
        return list(self._stores)
    @property
    def all_ids(self) -> list[str]:
        """
        :return: List of all ids generated by this component
        """
        return list(
            component_id
            for component_id in self._all_ids
            if component_id not in self.all_stores
        )
    def __repr__(self):
        return f"{self.id()}<{type(self).__name__}>"
    def __str__(self):
        ids = "\n".join(
            [f"* {component_id}  " for component_id in sorted(self.all_ids)]
        )
        stores = "\n".join([f"* {store}  " for store in sorted(self.all_stores)])
        layouts = "\n".join([f"* {layout}  " for layout in sorted(self._sub_layouts)])
        return f"""{self.id()}<{type(self).__name__}>  \n
IDs:  \n{ids}  \n
Stores:  \n{stores}  \n
Sub-layouts:  \n{layouts}"""
    @property
    def _sub_layouts(self):
        """
        Layouts associated with this component, available for book-keeping
        if your component is complex, so that the layout() method is just
        assembles individual sub-layouts.
        :return: A dictionary with names of layouts as keys (str) and Dash
        layouts (e.g. html.Div) as values.
        """
        return {}
[docs]    def layout(self) -> html.Div:
        """
        :return: A Dash layout for the full component. Basic implementation
        provided, but should in general be overridden.
        """
        return html.Div(list(self._sub_layouts.values()))
[docs]    def generate_callbacks(self, app, cache):
        """
        Generate all callbacks associated with the layouts in this app. Assume
        that "suppress_callback_exceptions" is True, since it is not always
        guaranteed that all layouts will be displayed to the end user at all
        times, but it's important the callbacks are defined on the server.
        """
        return None
[docs]    def get_numerical_input(
        self,
        kwarg_label: str,
        default: int | float | list | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str = None,
        is_int: bool = False,
        shape: tuple[int, ...] = (),
        **kwargs,
    ):
        """
        For Python classes which take matrices as inputs, this will generate
        a corresponding Dash input layout.
        :param kwarg_label: The name of the corresponding Python input, this is used
        to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with the kwarg_label as a key
        and the default value as a value. Ignored if `default` is set. It can be useful to use
        `state` if you want to set defaults for multiple inputs from a single dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :param is_int: if True, will use a numeric input
        :param shape: (3, 3) for matrix, (1, 3) for vector, (1, 1) for scalar
        :return: a Dash layout
        """
        state = state or {}
        default = np.full(shape, default or state.get(kwarg_label))
        default = np.reshape(default, shape)
        style = {
            "textAlign": "center",
            # shorter default width if matrix or vector
            "width": "5rem",
            "marginRight": "0.2rem",
            "marginBottom": "0.2rem",
            "height": "36px",
        }
        if "style" in kwargs:
            style.update(kwargs["style"])
            del kwargs["style"]
        def matrix_element(idx, value=0):
            # TODO: maybe move element out of the name
            mid = self.id(kwarg_label, is_kwarg=True, idx=idx, hint=shape)
            if isinstance(value, np.ndarray):
                value = value.item()
            if not is_int:
                return dcc.Input(
                    id=mid,
                    inputMode="numeric",
                    debounce=True,
                    className="input",
                    style=style,
                    value=float(value) if value is not None else None,
                    persistence=True,
                    type="number",
                    **kwargs,
                )
            else:
                return dcc.Input(
                    id=mid,
                    inputMode="numeric",
                    debounce=True,
                    className="input",
                    style=style,
                    value=int(value) if value is not None else None,
                    persistence=True,
                    type="number",
                    step=1,
                    **kwargs,
                )
        # dict of row indices, column indices to element
        matrix_contents = defaultdict(dict)
        # determine what individual input boxes we need
        # note that shape = () for floats, shape = (3,) for vectors
        # but we may also need to accept input for e.g. (3, 1)
        it = np.nditer(default, flags=["multi_index", "refs_ok"])
        while not it.finished:
            idx = it.multi_index
            row = (idx[1] if len(idx) > 1 else 0,)
            column = idx[0] if len(idx) > 0 else 0
            matrix_contents[row][column] = matrix_element(idx, value=it[0])
            it.iternext()
        # arrange the input boxes in two dimensions (rows, columns)
        matrix_div_contents = []
        print("matrix_contents", matrix_contents)
        for column_idx in sorted(matrix_contents):
            row = []
            for row_idx in sorted(matrix_contents[column_idx]):
                row.append(matrix_contents[column_idx][row_idx])
            matrix_div_contents.append(html.Div(row))
        matrix = html.Div(matrix_div_contents)
        return add_label_help(matrix, label, help_str)
[docs]    def get_slider_input(
        self,
        kwarg_label: str,
        default: Any | None = None,
        state: dict = None,
        label: str | None = None,
        help_str: str = None,
        multiple: bool = False,
        **kwargs,
    ):
        state = state or {}
        # TODO: bug if default == 0
        default = default or state.get(kwarg_label)
        # mpc.RangeSlider requires a domain to be specified
        slider_kwargs = {"domain": [0, default * 2]}
        slider_kwargs.update(**kwargs)
        if multiple:
            slider_input = mpc.DualRangeSlider(
                id=self.id(kwarg_label, is_kwarg=True, hint="slider"),
                value=default,
                **slider_kwargs,
            )
        else:
            slider_input = mpc.RangeSlider(
                id=self.id(kwarg_label, is_kwarg=True, hint="slider"),
                value=default,
                **slider_kwargs,
            )
        return add_label_help(slider_input, label, help_str)
[docs]    def get_bool_input(
        self,
        kwarg_label: str,
        default: bool | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str = None,
        **kwargs,
    ):
        """
        For Python classes which take boolean values as inputs, this will generate
        a corresponding Dash input layout.
        :param kwarg_label: The name of the corresponding Python input, this is used
        to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with the
            kwarg_label as a key
        and the default value as a value. Ignored if `default` is set. It can be useful
            to use `state` if you want to set defaults for multiple inputs from a single dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :return: a Dash layout
        """
        state = state or {}
        default = default or state.get(kwarg_label) or False
        bool_input = mpc.Switch(
            id=self.id(kwarg_label, is_kwarg=True, hint="bool"),
            value=True if default else False,
            hasLabel=True,
            **kwargs,
        )
        return add_label_help(bool_input, label, help_str)
[docs]    def get_choice_input(
        self,
        kwarg_label: str,
        default: str | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str = None,
        options: list[dict] | None = None,
        clearable: bool = False,
        **kwargs,
    ):
        """
        For Python classes which take pre-defined values as inputs, this will generate
        a corresponding input layout using mpc.Select.
        :param kwarg_label: The name of the corresponding Python input, this is used
        to name the component.
        :param label: A description for this input.
        :param default: A default value for this input.
        :param state: Used to set default state for this input, use a dict with the kwarg_label as a key
        and the default value as a value. Ignored if `default` is set. It can be useful to use
        `state` if you want to set defaults for multiple inputs from a single dictionary.
        :param help_str: Text for a tooltip when hovering over label.
        :param options: Options to choose from, as per dcc.Dropdown
        :param clearable: If True, will allow Dropdown to be cleared after a selection is made.
        :return: a Dash layout
        """
        state = state or {}
        default = default or state.get(kwarg_label)
        option_input = mpc.Select(
            id=self.id(kwarg_label, is_kwarg=True, hint="literal"),
            options=options if options else [],
            value=default,
            isClearable=clearable,
            arbitraryProps={**kwargs},
        )
        return add_label_help(option_input, label, help_str)
[docs]    def get_dict_input(
        self,
        kwarg_label: str,
        default: Any | None = None,
        state: dict | None = None,
        label: str | None = None,
        help_str: str = None,
        key_name: str = "key",
        value_name: str = "value",
    ):
        """
        :param kwarg_label:
        :param default:
        :param state:
        :param label:
        :param help_str:
        :param key_name:
        :param value_name:
        :return:
        """
        state = state or {}
        default = default or state.get(kwarg_label) or {}
        dict_input = dt.DataTable(
            id=self.id(kwarg_label, is_kwarg=True, hint="dict"),
            columns=[
                {"id": "key", "name": key_name},
                {"id": "value", "name": value_name},
            ],
            data=[{"key": k, "value": v} for k, v in default.items()],
            editable=True,
            persistence=False,
        )
        return add_label_help(dict_input, label, help_str)
[docs]    def get_kwarg_id(self, kwarg_name) -> dict:
        """
        :param kwarg_name:
        :return:
        """
        return {
            "component_id": self._id,
            "kwarg_label": kwarg_name,
            "idx": ALL,
            "hint": ALL,
        }
[docs]    def get_all_kwargs_id(self) -> dict:
        """
        :return:
        """
        return {"component_id": self._id, "kwarg_label": ALL, "idx": ALL, "hint": ALL}
[docs]    def reconstruct_kwarg_from_state(self, state, kwarg_name):
        return self.reconstruct_kwargs_from_state(
            state=state, kwarg_labels=[kwarg_name]
        )[kwarg_name]
[docs]    def reconstruct_kwargs_from_state(self, state=None, kwarg_labels=None) -> dict:
        """
        Generate
        :param state: optional, a Dash callback context input or state
        :param kwarg_labels: optional, parse only a specific kwarg or list of kwargs
        :return: A dictionary of keyword arguments with their values
        """
        if not state:
            state = {}
            state.update(dash.callback_context.inputs)
            state.update(dash.callback_context.states)
        kwargs = {}
        for k, v in state.items():
            # TODO: hopefully this will be less hacky in future Dash versions
            # remove trailing ".value" and convert back into dictionary
            # need to sort k somehow ...
            try:
                d = loads(k[: -len(".value")])
            except JSONDecodeError:
                continue
            kwarg_label = d["kwarg_label"]
            if kwarg_labels and kwarg_label not in kwarg_labels:
                continue
            try:
                k_type = literal_eval(d["hint"])
            except ValueError:
                k_type = d["hint"]
            idx = literal_eval(d["idx"])
            try:
                if isinstance(k_type, tuple):
                    # matrix or vector
                    if kwarg_label not in kwargs:
                        kwargs[kwarg_label] = np.empty(k_type)
                    v = literal_eval(str(v))
                    if (v is not None) and (kwargs[kwarg_label] is not None):
                        # print("debugging", kwargs, kwarg_label, idx, v)
                        if isinstance(v, list):
                            print(
                                "This shouldn't happen! Debug required.",
                                kwarg_label,
                                idx,
                                v,
                            )
                            kwargs[kwarg_label][idx] = None
                        else:
                            kwargs[kwarg_label][idx] = v
                    else:
                        # require all elements to have value, otherwise set
                        # entire kwarg to None
                        kwargs[kwarg_label] = None
                elif k_type == "literal":
                    try:
                        kwargs[kwarg_label] = literal_eval(str(v))
                    except (ValueError, SyntaxError):
                        kwargs[kwarg_label] = str(v)
                elif k_type == "bool":
                    kwargs[kwarg_label] = v
                elif k_type == "slider":
                    kwargs[kwarg_label] = v
                elif k_type == "dict":
                    pass
            except Exception as exc:
                # Not raised intentionally but if you notice this in logs please investigate.
                print("This is a problem, debug required.", exc, d, v, type(v))
        for k, v in kwargs.items():
            if isinstance(v, np.ndarray):
                kwargs[k] = v.tolist()
        if SETTINGS.DEBUG_MODE:
            print(type(self).__name__, "kwargs", kwargs)
        return kwargs
[docs]    @staticmethod
    def datauri_from_fig(
        fig, fmt: str = "png", width: int = 600, height: int = 400, scale: int = 4
    ) -> str:
        """
        Generate a data URI from a Plotly Figure.
        :param fig: Plotly Figure object or corresponding dictionary
        :param fmt: "png", "jpg", etc. (see PlotlyScope for supported formats)
        :param width: width in pixels
        :param height: height in pixels
        :param scale: scale factor
        :return:
        """
        from kaleido.scopes.plotly import PlotlyScope
        scope = PlotlyScope()
        output = scope.transform(
            fig, format=fmt, width=width, height=height, scale=scale
        )
        image = b64encode(output).decode("ascii")
        return f"data:image/{fmt};base64,{image}"
[docs]    def get_figure_placeholder(self, figure_id: str) -> html.Div:
        """
        Get a layout to act as a placeholder for an interactive figure.
        When used with `generate_static_figure_callbacks`, and assuming
        kaleido is installed on the server, a static image placeholder will
        be generated.
        :return:
        """
        return html.Div(
            [
                html.Div(
                    [Loading(id=self.id(f"{figure_id}-wrapped-figure-inner"))],
                    id=self.id("wrapped-figure-outer"),
                ),
                Button(
                    [Icon(kind="chart-pie"), html.Span(), "Make Plot Interactive"],
                    kind="primary",
                    id=self.id(f"{figure_id}-wrapped-figure-button"),
                ),
            ]
        )