# Source code for crystal_toolkit.components.transformations.core
from __future__ import annotations
import traceback
import warnings
import dash
import dash_daq as daq
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from pymatgen.transformations.transformation_abc import AbstractTransformation
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.helpers.layouts import (
Column,
Columns,
MessageBody,
MessageContainer,
MessageHeader,
Reveal,
)
from crystal_toolkit.settings import SETTINGS
class TransformationComponent(MPComponent):
    """
    Base component that wraps a single pymatgen transformation.

    Subclasses are expected to provide the ``transformation``, ``title`` and
    ``description`` properties (the base implementations raise
    ``NotImplementedError``) and may override ``options_layouts`` and
    ``get_preview_layout``. The subclass has to be named
    ``<TransformationName>Component``; this is enforced in ``__init__``.
    """

    def __init__(self, input_structure_component_id: str, *args, **kwargs):
        """
        :param input_structure_component_id: id passed on to ``MPComponent``
            as the "input_structure" link
        :param args: forwarded to ``MPComponent.__init__``
        :param kwargs: forwarded to ``MPComponent.__init__``
        :raises NameError: if the class name is not
            ``<TransformationName>Component``
        """
        expected_name = f"{self.transformation.__name__}Component"
        if type(self).__name__ != expected_name:
            # sanity check, enforcing conventions
            raise NameError(
                f"Class has to be named corresponding to the underlying "
                f"transformation name: {expected_name}"
            )

        super().__init__(
            *args, links={"input_structure": input_structure_component_id}, **kwargs
        )
        self.create_store(
            "transformation_args_kwargs", initial_data={"args": [], "kwargs": {}}
        )

    @property
    def is_one_to_many(self) -> bool:
        """
        This should reflect the underlying transformation.
        """
        # need to initialize transformation to access property, which isn't
        # possible in all cases without necessary kwargs, which is why
        # we duplicate the property here
        return False

    @property
    def _sub_layouts(self):
        """
        Build the individual layout pieces of this component.

        Note that every access builds a fresh set of components.

        :return: dict with the keys "options", "description", "enable",
            "message", "preview" and "ranked_list"
        """
        enable = daq.BooleanSwitch(
            id=self.id("enable_transformation"),
            style={"display": "inline-block", "vertical-align": "middle"},
        )

        message = html.Div(id=self.id("message"))

        description = dcc.Markdown(self.description)

        options = html.Div(self.options_layouts(), id=self.id("options"))

        preview = dcc.Loading(id=self.id("preview"))

        if self.is_one_to_many:
            ranked_list = daq.NumericInput(
                value=1, min=1, max=10, id=self.id("ranked_list")
            )
        else:
            # if not 1-to-many, we don't need the control, we keep
            # an empty container here to make the callbacks simpler
            # since "ranked_list" will then always be present in layout
            ranked_list = html.Div(id=self.id("ranked_list"))

        return {
            "options": options,
            "description": description,
            "enable": enable,
            "message": message,
            "preview": preview,
            "ranked_list": ranked_list,
        }

    def container_layout(self, state=None, structure=None) -> html.Div:
        """
        :param state: forwarded to ``options_layouts``
        :param structure: forwarded to ``options_layouts``
        :return: Layout defining transformation and its options.
        """
        # ``_sub_layouts`` is a property that rebuilds every sub-layout on
        # each access, so evaluate it once and reuse the result.
        sub_layouts = self._sub_layouts

        header = MessageHeader(
            html.Div(
                [
                    sub_layouts["enable"],
                    html.Span(
                        self.title,
                        style={
                            "vertical-align": "middle",
                            "margin-left": "1rem",
                        },
                    ),
                ]
            )
        )

        body = MessageBody(
            [
                Columns(
                    [
                        Column(
                            [
                                sub_layouts["description"],
                                html.Br(),
                                html.Div(
                                    self.options_layouts(
                                        state=state, structure=structure
                                    )
                                ),
                                html.Br(),
                                sub_layouts["message"],
                            ]
                        )
                    ]
                )
            ]
        )

        container = MessageContainer(
            [header, body],
            kind="dark",
            id=self.id("container"),
        )

        return container

    def options_layouts(self, state=None, structure=None) -> list[html.Div]:
        """
        Return a layout to change the transformation options (that is,
        that controls the args and kwargs that will be passed to pymatgen).

        The "state" option is so that the controls can be populated appropriately
        using existing args and kwargs, e.g. when restoring the control panel
        from a previous state.

        :param state: existing state in format {"args": [], "kwargs": {}}
        :param structure: not used by this base implementation
        :return: list containing a single empty ``html.Div``
        """
        return [html.Div()]

    @property
    def transformation(self):
        """The wrapped transformation class; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def title(self):
        """Title shown in the container header; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def description(self):
        """Markdown description of the transformation; must be provided by subclasses."""
        raise NotImplementedError

    def get_preview_layout(self, struct_in, struct_out):
        """
        Override this method to give a layout that previews the transformation.
        Has beneficial side effect of priming the transformation cache when
        entire transformation pipeline is enabled.

        :param struct_in: input Structure
        :param struct_out: transformed Structure
        :return: an empty ``html.Div`` in this base implementation
        """
        return html.Div()

    def generate_callbacks(self, app, cache):
        """
        Register the callbacks of this transformation on ``app``.

        :param app: object whose ``callback`` decorator is used to register
            the callbacks
        :param cache: object whose ``memoize`` decorator is used to cache
            results
        """

        @cache.memoize()
        def apply_transformation(transformation_data, struct):
            """
            Apply the transformation described by ``transformation_data``.

            :return: tuple ``(struct, error)``; on failure ``struct`` is the
                unchanged input and ``error`` is a list holding the error
                title and a ``Reveal`` with the traceback, otherwise
                ``error`` is None
            """
            transformation = self.from_data(transformation_data)
            error = None

            try:
                struct = transformation.apply_transformation(struct)
            except Exception as exc:
                error_title = (
                    f'Failed to apply "{type(transformation).__name__}" '
                    f"transformation: {exc}"
                )
                traceback_info = Reveal(
                    title=html.B("Traceback"),
                    children=[dcc.Markdown(traceback.format_exc())],
                )
                error = [error_title, traceback_info]

            return struct, error

        if SETTINGS.TRANSFORMATION_PREVIEWS:
            # Transformation previews need to be included in layout too (see preview sublayout)
            # Transformation previews need a full transformation pipeline replica (I/O heavy)
            # Might abandon.
            warnings.warn("Transformation previews under active development.")

            @app.callback(
                Output(self.id("preview"), "children"),
                [Input(self.id(), "data"), Input(self.id("input_structure"), "data")],
            )
            def update_preview(transformation_data, input_structure):
                """Render the preview layout for the current input structure."""
                if (not transformation_data) or (not input_structure):
                    return html.Div()

                input_structure = self.from_data(input_structure)
                # The error is not surfaced in the preview: on failure
                # apply_transformation hands back the unchanged input structure.
                output_structure, _error = apply_transformation(
                    transformation_data, input_structure
                )

                # TODO: warn when the transformed structure is large (> 64
                # atoms) and might take a moment to display; the previous
                # warning element was built but never added to the layout.

                return self.get_preview_layout(input_structure, output_structure)

        @app.callback(
            [
                Output(self.id(), "data"),
                Output(self.id("container"), "className"),
                Output(self.id("message"), "children"),
                Output(self.get_all_kwargs_id(), "disabled"),
            ],
            [Input(self.id("enable_transformation"), "on")],
            [State(self.get_all_kwargs_id(), "value")],
        )
        @cache.memoize(
            timeout=60 * 60 * 24,
            make_name=lambda x: f"{type(self).__name__}_{x}_cached",
        )
        def update_transformation(enabled, states):
            """
            Build the transformation from the current option values.

            :param enabled: value of the enable switch
            :param states: values of the option inputs; only its length is
                used here, the kwargs themselves are rebuilt from
                ``dash.callback_context.states``
            :return: tuple of (transformation or None, container class name,
                message layout, "disabled" flag for each option input)
            """
            # TODO: move callback inside AllTransformationsComponent for efficiency?

            kwargs = self.reconstruct_kwargs_from_state(dash.callback_context.states)

            # for debug
            # print("transformation kwargs", kwargs)

            if not enabled:
                # switch is off: every option input gets disabled=False
                input_state = (False,) * len(states)
                return None, "message is-dark", html.Div(), input_state

            # switch is on: every option input gets disabled=True
            input_state = (True,) * len(states)

            try:
                trans = self.transformation(**kwargs)
                error = None
            except Exception as exception:
                trans = None
                error = str(exception)

            if error:
                return (
                    trans,
                    "message is-warning",
                    html.Strong(f"Error: {error}"),
                    input_state,
                )

            return trans, "message is-success", html.Div(), input_state
class AllTransformationsComponent(MPComponent):
    """
    Component that offers several ``TransformationComponent`` subclasses and
    applies the enabled ones, in order, to an input structure.
    """

    def __init__(
        self,
        transformations: list[str] | None = None,
        input_structure_component: MPComponent | None = None,
        *args,
        **kwargs,
    ):
        """
        Create a component that can manage multiple transformations in a
        user-defined order.

        :param transformations: if provided, only offer a subset of available
            transformations, provide as a string of the given transformation name
        :param input_structure_component: will supply the structure to transform
        :param args: forwarded to ``MPComponent.__init__``
        :param kwargs: forwarded to ``MPComponent.__init__``
        """
        # get available transformations
        subclasses = TransformationComponent.__subclasses__()
        subclass_names = [s.__name__ for s in subclasses]

        transformations = transformations or subclass_names
        for name in transformations:
            if name not in subclass_names:
                # unknown names only trigger a warning and are then dropped
                # by the filter below
                warnings.warn(
                    f'Unknown transformation "{name}", choose from: {", ".join(subclass_names)}'
                )

        transformations = [t for t in subclasses if t.__name__ in transformations]

        super().__init__(*args, **kwargs)
        if input_structure_component:
            self.links["input_structure"] = input_structure_component.id()
        self.create_store("input_structure")
        self.create_store("enabled-transformations", initial_data=[])

        transformations = [
            t(input_structure_component_id=self.id("input_structure"))
            for t in transformations
        ]

        # keyed by component class name, e.g. "<TransformationName>Component"
        self.transformations = {type(t).__name__: t for t in transformations}

    @property
    def _sub_layouts(self):
        """
        Extend the parent sub-layouts with "all_transformations" (the
        container layouts of every transformation) and "choices" (a
        multi-select dropdown of transformation titles).
        """
        layouts = super()._sub_layouts

        all_transformations = html.Div(
            [
                transformation.container_layout()
                for transformation in self.transformations.values()
            ]
        )

        choices = dcc.Dropdown(
            options=[
                {"label": transformation.title, "value": name}
                for name, transformation in self.transformations.items()
            ],
            multi=True,
            value=[],
            placeholder="Select one or more transformations...",
            id=self.id("choices"),
            style={"max-width": "65vmin"},
            persistence=True,
        )

        layouts.update({"all_transformations": all_transformations, "choices": choices})

        return layouts

    def layout(self):
        """
        :return: panel with a description, the transformation dropdown, and
            empty containers for the error message and the options
        """
        return html.Div(
            [
                html.Div(
                    "Transform your crystal structure using the power of pymatgen.",
                    className="mpc-panel-description",
                ),
                self._sub_layouts["choices"],
                html.Br(),
                html.Div(id=self.id("error")),
                html.Div(id=self.id("transformation_options")),
            ]
        )

    def generate_callbacks(self, app, cache):
        """
        Register the callbacks of this component on ``app``.

        :param app: object whose ``callback`` decorator is used to register
            the callbacks
        :param cache: object whose ``memoize`` decorator is used to cache
            results
        """

        @cache.memoize()
        def apply_transformation(transformation_data, struct):
            """
            Apply the transformation described by ``transformation_data``.

            :return: tuple ``(struct, error)``; on failure ``struct`` is the
                unchanged input and ``error`` is a list holding the error
                title and a ``Reveal`` with the traceback, otherwise
                ``error`` is None
            """
            transformation = self.from_data(transformation_data)
            error = None

            try:
                if not isinstance(transformation, AbstractTransformation):
                    raise ValueError(
                        f"Can't run transformation: {transformation} is {type(transformation)}"
                    )
                struct = transformation.apply_transformation(struct)
            except Exception as exc:
                error_title = html.Span(
                    f'Failed to apply "{type(transformation).__name__}" '
                    f"transformation: {exc}"
                )
                traceback_info = Reveal(
                    id=self.id("Error"),
                    title=html.B("Traceback"),
                    children=[dcc.Markdown(traceback.format_exc())],
                )
                error = [error_title, traceback_info]

            return struct, error

        @app.callback(
            Output(self.id("transformation_options"), "children"),
            [
                Input(self.id("input_structure"), "data"),
                Input(self.id("choices"), "value"),
            ],
            [State(t.id(), "data") for t in self.transformations.values()],
        )
        def show_transformation_options(structure, values, *args):
            """
            Show the option panels of the selected transformations.

            :param structure: data of the input structure store
            :param values: names selected in the dropdown, in selection order
            :param args: stored data of every transformation, in the order of
                ``self.transformations``
            """
            # for debug
            # print(dash.callback_context.triggered)

            values = values or []

            # ``args`` follows the order of ``self.transformations`` while
            # ``values`` follows the user's selection, so states have to be
            # looked up by name rather than paired positionally.
            state_by_name = dict(zip(self.transformations, args))

            structure = self.from_data(structure)

            transformation_options = html.Div(
                [
                    self.transformations[name].container_layout(
                        state=state_by_name[name], structure=structure
                    )
                    for name in values
                ]
            )

            return [transformation_options]

        @app.callback(
            Output(self.id("enabled-transformations"), "data"),
            Input(self.id("choices"), "value"),
        )
        def set_enabled_transformations(value):
            """
            Mirror the dropdown selection into the "enabled-transformations"
            store.

            This is due to an unfortunate but noisy bug that
            complains that this specific input is not present
            in the layout on load.
            """
            return value

        # TODO: make an error store too
        @app.callback(
            # [
            Output(self.id(), "data"),
            # Output(self.id("error"), "children")],
            [Input(t.id(), "data") for t in self.transformations.values()]
            + [
                Input(self.id("input_structure"), "data"),
                Input(self.id("enabled-transformations"), "data"),
            ],
        )
        def run_transformations(*args):
            """
            Apply every enabled, user-visible transformation in turn.

            :param args: stored data of every transformation, followed by the
                input structure data and the list of enabled transformation
                names
            :return: the transformed structure
            :raises PreventUpdate: if there is no input structure
            """
            *transformation_states, input_structure, enabled_names = args

            # do not update if we don't have a Structure to transform
            if not input_structure:
                raise PreventUpdate

            # the store may hold None instead of a list; treat as "none enabled"
            user_visible_transformations = enabled_names or []
            struct = self.from_data(input_structure)

            # for debug
            # print("input struct", struct)

            errors = []

            # transformations that are switched off store no data
            transformations = [state for state in transformation_states if state]

            if not transformations:
                return struct  # , html.Div()

            for transformation_data in transformations:
                # following our naming convention, only apply transformations
                # that are user visible
                # TODO: this should be changed
                if (
                    f"{transformation_data['@class']}Component"
                    in user_visible_transformations
                ):
                    struct, error = apply_transformation(transformation_data, struct)

                    if error:
                        errors += error

            # NOTE: error_msg is not returned yet; it is kept ready for the
            # commented-out "error" output above (see TODO).
            if not errors:
                error_msg = html.Div()
            else:
                errors = [
                    dcc.Markdown(
                        "Crystal Toolkit encountered an error when trying to "
                        "applying your chosen transformations. This is usually "
                        "because either the input crystal structure is not "
                        "suitable for the transformation, or the choice of "
                        "transformation settings is not appropriate. Consult "
                        "the pymatgen documentation for more information.  \n"
                        ""
                        "If you think this is a bug please report it.  \n"
                        ""
                    )
                ] + errors
                error_msg = html.Div(
                    [
                        MessageContainer(
                            [
                                MessageHeader("Error applying transformations"),
                                MessageBody(errors),
                            ],
                            kind="danger",
                        ),
                        html.Br(),
                    ]
                )

            # for debug
            # print("transformed struct", struct)

            return struct  # , error_msg

        # callback to take all transformations
        # and also state of which transformations are user-visible (+ their order)
        # apply them one by one with kwargs
        # external error callback(?) for each transformation, have ext error + combine with trans error