import itertools
from multiprocessing import cpu_count
from warnings import warn
import dash_mp_components as mpc
from dash import callback_context, dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import (
SimplestChemenvStrategy,
)
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import (
AllCoordinationGeometries,
)
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import (
LocalGeometryFinder,
)
from pymatgen.analysis.chemenv.coordination_environments.structure_environments import (
LightStructureEnvironments,
)
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
from pymatgen.analysis.local_env import LocalStructOrderParams, cn_opt_params
from pymatgen.core.structure import Molecule, Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.string import unicodeify, unicodeify_species
from sklearn.preprocessing import normalize
from crystal_toolkit.components.structure import StructureMoleculeComponent
from crystal_toolkit.core.legend import Legend
from crystal_toolkit.core.panelcomponent import PanelComponent
from crystal_toolkit.helpers.layouts import (
H5,
Column,
Columns,
Label,
Loading,
cite_me,
get_data_list,
get_tooltip,
)
# dscribe is an optional dependency: when it is missing, both imported names
# are bound to None so that callers can test for availability explicitly
# instead of hitting a NameError on REMatchKernel.
try:
    from dscribe.descriptors import SOAP
    from dscribe.kernels import REMatchKernel
except ImportError:
    warn(
        "Using dscribe SOAP and REMatchKernel requires the dscribe package "
        "which was made optional since it in turn requires numba and numba "
        "was a common source of installation issues."
    )
    SOAP = None
    REMatchKernel = None
def _get_local_order_parameters(structure_graph, n):
    """
    Compute local structure order parameters for one site of a StructureGraph.

    This mirrors the method of the same purpose in pymatgen.analysis.local_env
    but works on a StructureGraph directly. Only the order parameters whose
    ideal coordination number equals the coordination number of the site are
    evaluated (e.g., for CN=4: square planar, tetrahedral, see-saw-like and
    rectangular see-saw-like).

    Args:
        structure_graph: StructureGraph object
        n (int): site index.

    Returns (Dict[str, float]):
        A dict mapping the motif type (for example, tetrahedral) to its order
        parameter, or None when the site's coordination number has no entry
        in ``cn_opt_params``.
    """
    # TODO: move me to pymatgen once stable
    # code from @nisse3000, moved here from graphs to avoid circular
    # import, also makes sense to have this as a general NN method
    cn = structure_graph.get_coordination_of_site(n)
    supported_cns = [int(key) for key in cn_opt_params]
    if cn not in supported_cns:
        return None

    motif_table = cn_opt_params[cn]
    motif_names = list(motif_table)

    # Each entry is [type] or [type, parameters]; a missing second element
    # is passed on as None.
    op_types = [motif_table[motif][0] for motif in motif_names]
    op_params = [
        motif_table[motif][1] if len(motif_table[motif]) > 1 else None
        for motif in motif_names
    ]
    order_params = LocalStructOrderParams(op_types, parameters=op_params)

    # The central site goes first (index 0), followed by its bonded neighbors.
    central_and_neighbors = [structure_graph.structure[n]]
    central_and_neighbors.extend(
        connected.site for connected in structure_graph.get_connected_sites(n)
    )
    values = order_params.get_order_parameters(
        central_and_neighbors, 0, indices_neighs=list(range(1, cn + 1))
    )
    return dict(zip(motif_names, values))
class LocalEnvironmentPanel(PanelComponent):
    """
    Panel that analyzes the local chemical environments of a crystal.

    The user chooses between four analysis methods (ChemEnv, LocalEnv, a
    bonding-graph view and SOAP). The controls and results for the chosen
    method are produced by the callbacks registered in ``generate_callbacks``.
    """

    def __init__(self, *args, **kwargs):
        """Create the panel plus its ``graph`` and ``display_options`` stores."""
        super().__init__(*args, **kwargs)
        self.create_store("graph")
        self.create_store(
            "display_options",
            initial_data={"color_scheme": "Jmol", "color_scale": None},
        )

    @property
    def title(self):
        """Title of the panel."""
        return "Local Environments"

    @property
    def description(self):
        """One-sentence description of the panel."""
        return "Analyze the local chemical environments in your crystal."

    @property
    def loading_text(self):
        """Text describing the work in progress while the panel loads."""
        return "Analyzing environments"

    def contents_layout(self) -> html.Div:
        """
        Build the static layout: the algorithm selector followed by an empty
        ``analysis`` container that the ``run_algorithm`` callback fills.
        """
        algorithm_choices = self.get_choice_input(
            label="Analysis method",
            kwarg_label="algorithm",
            state={"algorithm": "chemenv"},
            options=[
                {"label": "ChemEnv", "value": "chemenv"},
                {"label": "LocalEnv", "value": "localenv"},
                {"label": "Bonding Graph", "value": "bondinggraph"},
                {"label": "SOAP", "value": "soap"},
            ],
            help_str="Choose an analysis method to examine the local chemical environment. "
            "Several methods exist and there is no guaranteed correct answer, so try multiple!",
        )

        analysis = html.Div(id=self.id("analysis"))

        return html.Div([algorithm_choices, html.Br(), analysis, html.Br()])

    @staticmethod
    def get_graph_data(graph, display_options):
        """
        Convert a bonding graph into node and edge dicts for the graph widget.

        Args:
            graph: object exposing a ``graph`` attribute (with ``nodes()`` and
                ``edges(data=True)``) and ``get_coordination_of_site``;
                presumably a StructureGraph or MoleculeGraph -- confirm
                against callers.
            display_options: dict of display options; only ``"color_scheme"``
                is read here (default ``"Jmol"``).

        Returns:
            Dict with two keys, ``"nodes"`` and ``"edges"``, each a list of
            dicts.
        """
        color_scheme = display_options.get("color_scheme", "Jmol")

        nodes = []
        edges = []

        struct_or_mol = StructureMoleculeComponent._get_struct_or_mol(graph)
        legend = Legend(struct_or_mol, color_scheme=color_scheme)

        for node in graph.graph.nodes():
            site = struct_or_mol[node]

            # TODO: fix for disordered
            node_color = legend.get_color(site.species.elements[0], site=site)

            # The node id doubles as the site index, so use it for the
            # coordination lookup as well (not the enumeration position).
            nodes.append(
                {
                    "id": node,
                    "title": f"{site.species_string} site "
                    f"({graph.get_coordination_of_site(node)} neighbors)",
                    "color": node_color,
                }
            )

        for u, v, d in graph.graph.edges(data=True):
            edge = {"from": u, "to": v, "arrows": ""}

            to_jimage = d.get("to_jimage", (0, 0, 0))

            # TODO: check these edge weights
            if isinstance(struct_or_mol, Structure):
                dist = struct_or_mol.get_distance(u, v, jimage=to_jimage)
            else:
                dist = struct_or_mol.get_distance(u, v)
            edge["length"] = 50 * dist

            # Bonds that cross into a periodic image are drawn as arrows.
            if to_jimage != (0, 0, 0):
                edge["arrows"] = "to"
                edge["title"] = f"{dist:.2f} Å to site at image vector {to_jimage}"
            else:
                edge["title"] = f"{dist:.2f} Å between sites"

            edges.append(edge)

        return {"nodes": nodes, "edges": edges}

    def generate_callbacks(self, app, cache):
        """
        Register the panel's Dash callbacks.

        Args:
            app: object whose ``callback`` decorator is used to register each
                callback.
            cache: object whose ``memoize`` decorator caches the Materials
                Project query used for SOAP similarities.
        """
        super().generate_callbacks(app, cache)

        @app.callback(
            Output(self.id("analysis"), "children"),
            [Input(self.get_kwarg_id("algorithm"), "value")],
        )
        def run_algorithm(algorithm):
            """Return the description and controls for the selected algorithm."""
            algorithm = self.reconstruct_kwarg_from_state(
                callback_context.inputs, "algorithm"
            )

            if algorithm == "chemenv":

                state = {"distance_cutoff": 1.4, "angle_cutoff": 0.3}

                description = (
                    "The ChemEnv algorithm is developed by David Waroquiers et al. to analyze "
                    'local chemical environments. In this interactive app, the "SimplestChemenvStrategy" '
                    'and "LightStructureEnvironments" are used. For more powerful analysis, please use '
                    "the *pymatgen* code directly. Note that this analysis determines its own bonds independent "
                    "of those shown in the main crystal visualizer."
                )

                distance_cutoff = self.get_numerical_input(
                    label="Distance cut-off",
                    kwarg_label="distance_cutoff",
                    state=state,
                    help_str="Defines search radius by considering any atom within a radius "
                    "of the minimum nearest neighbor distance multiplied by the distance "
                    "cut-off.",
                    shape=(),
                )
                angle_cutoff = self.get_numerical_input(
                    label="Angle cut-off",
                    kwarg_label="angle_cutoff",
                    state=state,
                    help_str="Defines a tolerance whereby a neighbor atom is excluded if the solid angle "
                    "circumscribed by its Voronoi face is smaller than the angle tolerance "
                    "multiplied by the largest solid angle present in the crystal.",
                    shape=(),
                )

                return html.Div(
                    [
                        dcc.Markdown(description),
                        html.Br(),
                        cite_me(
                            cite_text="How to cite ChemEnv",
                            doi="10.1107/S2052520620007994",
                        ),
                        html.Br(),
                        distance_cutoff,
                        angle_cutoff,
                        html.Br(),
                        Loading(id=self.id("chemenv_analysis")),
                    ]
                )

            elif algorithm == "localenv":

                description = (
                    "The LocalEnv algorithm is developed by Nils Zimmerman et al. whereby "
                    "an 'order parameter' is calculated that measures how well that "
                    "environment matches an ideal polyhedra. The order parameter "
                    "is a number from zero to one, with one being a perfect match."
                )

                return html.Div(
                    [
                        dcc.Markdown(description),
                        html.Br(),
                        cite_me(
                            cite_text="How to cite LocalEnv",
                            doi="10.3389/fmats.2017.00034",
                        ),
                        html.Br(),
                        Loading(id=self.id("localenv_analysis")),
                    ]
                )

            elif algorithm == "bondinggraph":

                description = (
                    "This is an alternative way to display the same bonds present in the "
                    "visualizer. Here, the bonding is displayed as a crystal graph, with "
                    "nodes as atoms and edges as bonds. The graph visualization is shown in an "
                    "abstract two-dimensional space."
                )

                return html.Div(
                    [
                        dcc.Markdown(description),
                        html.Br(),
                        Loading(id=self.id("bondinggraph_analysis")),
                    ]
                )

            elif algorithm == "soap":

                state = {
                    "rcut": 5.0,
                    "nmax": 2,
                    "lmax": 2,
                    "sigma": 0.2,
                    "crossover": True,
                    "average": False,
                    "rbf": "gto",
                    "alpha": 0.1,
                    "threshold": 1e-4,
                    "metric": "linear",
                    "normalize_kernel": True,
                }

                description = (
                    'The "Smooth Overlap of Atomic Positions" (SOAP) descriptor provides information on the local '
                    "atomic environment by encoding that environment as a power spectrum derived from the "
                    "spherical harmonics of atom-centered gaussian densities. The SOAP formalism is complex but is "
                    "described well in [Bartók et al.](https://doi.org/10.1103/PhysRevB.87.184115) "
                    "and the REMatch similarity kernel in [De et al.](https://doi.org/10.1039/c6cp00415f) "
                    "The implementation of SOAP in this "
                    "web app is provided by [DScribe](https://doi.org/10.1016/j.cpc.2019.106949). "
                    ""
                    "SOAP kernels are commonly used in machine learning applications. This interface is provided to "
                    "help gain intuition and exploration of the behavior of SOAP kernels."
                )

                rcut = self.get_numerical_input(
                    label="Radial cut-off /Å",
                    kwarg_label="rcut",
                    state=state,
                    help_str="The radial cut-off that defines the local region being considered",
                    shape=(),
                    min=1.0001,
                )
                nmax = self.get_numerical_input(
                    label="N max.",
                    kwarg_label="nmax",
                    state=state,
                    help_str="Number of radial basis functions",
                    shape=(),
                    is_int=True,
                    min=1,
                    max=9,
                )
                lmax = self.get_numerical_input(
                    label="L max.",
                    kwarg_label="lmax",
                    state=state,
                    help_str="Maximum degree of spherical harmonics",
                    shape=(),
                    is_int=True,
                    min=1,
                    max=9,
                )
                sigma = self.get_numerical_input(
                    label="Sigma",
                    kwarg_label="sigma",
                    state=state,
                    help_str="The standard deviation of gaussians used to build atomic density",
                    shape=(),
                    min=0.00001,
                )
                rbf = self.get_choice_input(
                    label="Radial basis function",
                    kwarg_label="rbf",
                    state=state,
                    help_str="Polynomial basis is faster, spherical gaussian based was used in original formulation",
                    options=[
                        {"label": "Spherical gaussian basis", "value": "gto"},
                        {"label": "Polynomial basis", "value": "polynomial"},
                    ],
                    style={"width": "16rem"},  # TODO: remove in-line style
                )
                crossover = self.get_bool_input(
                    label="Crossover",
                    kwarg_label="crossover",
                    state=state,
                    help_str="If enabled, the power spectrum will include all combinations of elements present.",
                )
                average = self.get_bool_input(
                    label="Average",
                    kwarg_label="average",
                    state=state,
                    help_str="If enabled, the SOAP vector will be averaged across all sites.",
                )
                alpha = self.get_numerical_input(
                    label="Alpha",
                    kwarg_label="alpha",
                    state=state,
                    help_str="Determines the entropic penalty in the REMatch kernel. As alpha goes to infinity, the "
                    "behavior of the REMatch kernel matches the behavior of the kernel where SOAP vectors "
                    "are averaged across all sites. As alpha goes to zero, the kernel matches the best match "
                    "kernel.",
                    shape=(),
                    min=0.00001,
                )
                threshold = self.get_numerical_input(
                    label="Sinkhorn threshold",
                    kwarg_label="threshold",
                    state=state,
                    help_str="Convergence threshold for the Sinkhorn algorithm. If values are too small, convergence "
                    "may not be possible, and calculation time will increase.",
                    shape=(),
                )
                metric = self.get_choice_input(
                    label="Metric",
                    kwarg_label="metric",
                    state=state,
                    help_str='See scikit-learn\'s documentation on "Pairwise metrics, Affinities and Kernels" '
                    "for an explanation of available metrics.",
                    options=[
                        # {"label": "Additive χ2", "value": "additive_chi2"},  # these seem to be unstable
                        # {"label": "Exponential χ2", "value": "chi2"},
                        {"label": "Linear", "value": "linear"},
                        {"label": "Polynomial", "value": "polynomial"},
                        {"label": "Radial basis function", "value": "rbf"},
                        {"label": "Laplacian", "value": "laplacian"},
                        {"label": "Sigmoid", "value": "sigmoid"},
                        {"label": "Cosine", "value": "cosine"},
                    ],
                    style={"width": "16rem"},  # TODO: remove in-line style
                )
                # Built but not yet placed in the layout below (see the
                # commented-out "normalize_kernel" entry).
                normalize_kernel = self.get_bool_input(  # noqa: F841
                    label="Normalize",
                    kwarg_label="normalize_kernel",
                    state=state,
                    help_str="Whether or not to normalize the resulting similarity kernel.",
                )

                # metric_kwargs = self.get_dict_input()

                return html.Div(
                    [
                        dcc.Markdown(description),
                        html.Br(),
                        H5("SOAP parameters"),
                        rcut,
                        nmax,
                        lmax,
                        sigma,
                        rbf,
                        crossover,
                        average,
                        html.Br(),  # TODO: remove all html.Br(), add appropriate styles instead
                        html.Br(),
                        html.Div(id=self.id("soap_analysis")),
                        html.Br(),
                        html.Br(),
                        H5("Similarity metric parameters"),
                        html.Div(
                            "This will calculate structural similarity scores from materials in the "
                            "Materials Project in the same chemical system. Note that for large chemical "
                            "systems this step can take several minutes."
                        ),
                        html.Br(),
                        alpha,
                        threshold,
                        metric,
                        # normalize_kernel,
                        html.Br(),
                        html.Br(),
                        Loading(id=self.id("soap_similarities")),
                    ]
                )

        def _dscribe_missing_notice():
            """Message shown by the SOAP callbacks when dscribe is not importable."""
            return mpc.Markdown(
                "This feature will not work unless `dscribe` is installed on the server."
            )

        def _make_soap_descriptor(species, kwargs):
            """Build a periodic, dense SOAP descriptor from the user's settings."""
            # TODO: make sure is_int kwarg information is enforced so that int() conversion is unnecessary
            return SOAP(
                species=species,
                sigma=kwargs["sigma"],
                rcut=kwargs["rcut"],
                nmax=int(kwargs["nmax"]),
                lmax=int(kwargs["lmax"]),
                periodic=True,
                crossover=kwargs["crossover"],
                sparse=False,
                average=kwargs["average"],
            )

        def _get_soap_graph(feature, label):
            """Render a SOAP feature array as a small heatmap next to ``label``."""
            spectrum = {
                "data": [
                    {
                        "coloraxis": "coloraxis",
                        # 'hovertemplate': 'x: %{x}<br>y: %{y}<br>color: %{z}<extra></extra>',
                        "type": "heatmap",
                        "z": feature.tolist(),
                    }
                ],
                "layout": {
                    "xaxis": {"visible": False},
                    "yaxis": {"visible": False},
                    "paper_bgcolor": "rgba(0,0,0,0)",
                    "plot_bgcolor": "rgba(0,0,0,0)",
                    "coloraxis": {
                        "colorscale": [
                            [0.0, "#0d0887"],
                            [0.1111111111111111, "#46039f"],
                            [0.2222222222222222, "#7201a8"],
                            [0.3333333333333333, "#9c179e"],
                            [0.4444444444444444, "#bd3786"],
                            [0.5555555555555556, "#d8576b"],
                            [0.6666666666666666, "#ed7953"],
                            [0.7777777777777778, "#fb9f3a"],
                            [0.8888888888888888, "#fdca26"],
                            [1.0, "#f0f921"],
                        ],
                        "showscale": False,
                    },
                    "margin": {"l": 0, "b": 0, "t": 0, "r": 0, "pad": 0},
                    # "height": 20*feature.shape[0],  # for fixed size plots
                    # "width": 20*feature.shape[1]
                },
            }

            return Columns(
                [
                    Column(Label(label), size="1"),
                    Column(
                        dcc.Graph(
                            figure=spectrum,
                            config={"displayModeBar": False},
                            responsive=True,
                            style={"height": "60px"},
                        )
                    ),
                ]
            )

        @app.callback(
            Output(self.id("soap_analysis"), "children"),
            [Input(self.id(), "data"), Input(self.get_all_kwargs_id(), "value")],
        )
        def update_soap_analysis(struct, all_kwargs):
            """Show the normalized SOAP vector(s) of the current structure."""
            if not struct:
                raise PreventUpdate

            if SOAP is None:
                return _dscribe_missing_notice()

            struct = self.from_data(struct)
            kwargs = self.reconstruct_kwargs_from_state(callback_context.inputs)

            desc = _make_soap_descriptor(
                [e.number for e in struct.composition.elements], kwargs
            )

            atoms = AseAtomsAdaptor().get_atoms(struct)
            feature = normalize(desc.create(atoms, n_jobs=cpu_count()))

            return _get_soap_graph(feature, "SOAP vector for this material")

        @cache.memoize(timeout=360)
        def _get_all_structs_from_elements(elements):
            """
            Query the Materials Project for every structure in any chemical
            sub-system of ``elements``; returns ``{task_id: structure}``.
            """
            # Every non-empty combination of the elements, e.g. A, B, A-B.
            all_chemsyses = [
                "-".join(sorted(els))
                for size in range(1, len(elements) + 1)
                for els in itertools.combinations(elements, size)
            ]

            structs = {}
            with MPRester() as mpr:
                docs = mpr.query(
                    {"chemsys": {"$in": all_chemsyses}}, ["task_id", "structure"]
                )
                structs.update({d["task_id"]: d["structure"] for d in docs})

            return structs

        @app.callback(
            Output(self.id("soap_similarities"), "children"),
            [Input(self.id(), "data"), Input(self.get_all_kwargs_id(), "value")],
        )
        def update_soap_similarities(struct, all_kwargs):
            """
            Rank Materials Project structures from the same chemical system by
            REMatch similarity of their SOAP vectors to the current structure.
            """
            if not struct:
                raise PreventUpdate

            if SOAP is None:
                return _dscribe_missing_notice()

            # The user's structure is stored under the reserved key "input".
            structs = {"input": self.from_data(struct)}
            kwargs = self.reconstruct_kwargs_from_state(callback_context.inputs)

            elements = [str(el) for el in structs["input"].composition.elements]
            structs.update(_get_all_structs_from_elements(elements))

            # One descriptor shared by all structures, so it must know every
            # element that occurs in any of them.
            all_elements = {
                elem for s in structs.values() for elem in s.composition.elements
            }
            desc = _make_soap_descriptor([e.number for e in all_elements], kwargs)

            adaptor = AseAtomsAdaptor()
            atomss = {mpid: adaptor.get_atoms(s) for mpid, s in structs.items()}

            print(f"Calculating {len(atomss)} SOAP vectors")
            features = {
                mpid: normalize(desc.create(atoms, n_jobs=cpu_count()))
                for mpid, atoms in atomss.items()
            }

            kernel = REMatchKernel(
                metric=kwargs["metric"],
                alpha=kwargs["alpha"],
                threshold=kwargs["threshold"],
                # normalize_kernel=kwargs["normalize_kernel"],
            )

            print("Calculating similarity kernel")
            similarities = {
                mpid: kernel.get_global_similarity(
                    kernel.get_pairwise_matrix(features["input"], feature)
                )
                for mpid, feature in features.items()
                if mpid != "input"
            }

            # Most similar first.
            sorted_mpids = sorted(similarities, key=similarities.get, reverse=True)

            print("Generating similarity graphs")
            # TODO: was much slower using px.imshow (see prev commit)
            all_graphs = [
                _get_soap_graph(
                    features[mpid],
                    [
                        html.Span(
                            f"{unicodeify(structs[mpid].composition.reduced_formula)}"
                        ),
                        dcc.Markdown(f"[{mpid}](https://materialsproject.org/{mpid})"),
                        html.Span(f"{similarities[mpid]:.5f}"),
                    ],
                )
                for mpid in sorted_mpids
            ]

            print("Returning similarity graphs")
            return html.Div(all_graphs)

        @app.callback(
            Output(self.id("localenv_analysis"), "children"),
            [Input(self.id("graph"), "data")],
        )
        def update_localenv_analysis(graph):
            """Show the order parameters of site 0 of the stored graph."""
            if not graph:
                raise PreventUpdate

            graph = self.from_data(graph)

            return html.Div(
                [
                    str(_get_local_order_parameters(graph, 0)),
                    html.Br(),
                    html.Small("This functionality is still under development."),
                ]
            )

        @app.callback(
            Output(self.id("bondinggraph_analysis"), "children"),
            [
                Input(self.id("graph"), "data"),
                Input(self.id("display_options"), "data"),
            ],
        )
        def update_bondinggraph_analysis(graph, display_options):
            """Render the stored bonding graph as an interactive network."""
            if not graph:
                raise PreventUpdate

            graph = self.from_data(graph)
            display_options = self.from_data(display_options)

            graph_data = self.get_graph_data(graph, display_options)

            options = {
                "interaction": {
                    "hover": True,
                    "tooltipDelay": 0,
                    "zoomView": False,
                    "dragView": False,
                },
                "edges": {
                    "smooth": {"type": "dynamic"},
                    "length": 250,
                    "color": {"inherit": "both"},
                },
                "physics": {
                    "solver": "forceAtlas2Based",
                    "forceAtlas2Based": {"avoidOverlap": 1.0},
                    "stabilization": {"fit": True},
                },
            }

            return html.Div(
                [mpc.GraphComponent(graph=graph_data, options=options)],
                style={"width": "65vmin", "height": "65vmin"},
            )

        @app.callback(
            Output(self.id("chemenv_analysis"), "children"),
            [
                Input(self.id(), "data"),
                Input(self.get_kwarg_id("distance_cutoff"), "value"),
                Input(self.get_kwarg_id("angle_cutoff"), "value"),
            ],
        )
        def get_chemenv_analysis(struct, distance_cutoff, angle_cutoff):
            """
            Run ChemEnv on the symmetry-inequivalent sites and lay the
            identified environments out two per row.
            """
            if not struct:
                raise PreventUpdate

            struct = self.from_data(struct)
            # The cut-offs are taken from the reconstructed kwargs rather than
            # from the raw callback arguments.
            kwargs = self.reconstruct_kwargs_from_state(callback_context.inputs)
            distance_cutoff = kwargs["distance_cutoff"]
            angle_cutoff = kwargs["angle_cutoff"]

            # TODO: remove these brittle guard statements, figure out more robust way to handle multiple input types
            if isinstance(struct, StructureGraph):
                struct = struct.structure

            def get_valences(struct):
                """Oxidation states of all sites, or "undefined" if any is missing."""
                valences = [getattr(site.specie, "oxi_state", None) for site in struct]
                valences = [v for v in valences if v is not None]
                if len(valences) == len(struct):
                    return valences
                return "undefined"

            # decide which indices to present to user
            sga = SpacegroupAnalyzer(struct)
            symm_struct = sga.get_symmetrized_structure()
            inequivalent_indices = [
                indices[0] for indices in symm_struct.equivalent_indices
            ]
            wyckoffs = symm_struct.wyckoff_symbols

            lgf = LocalGeometryFinder()
            lgf.setup_structure(structure=struct)

            se = lgf.compute_structure_environments(
                maximum_distance_factor=distance_cutoff + 0.01,
                only_indices=inequivalent_indices,
                valences=get_valences(struct),
            )
            strategy = SimplestChemenvStrategy(
                distance_cutoff=distance_cutoff, angle_cutoff=angle_cutoff
            )
            lse = LightStructureEnvironments.from_structure_environments(
                strategy=strategy, structure_environments=se
            )
            all_ce = AllCoordinationGeometries()

            envs = []
            unknown_sites = []

            for index, wyckoff in zip(inequivalent_indices, wyckoffs):

                datalist = {
                    "Site": unicodeify_species(struct[index].species_string),
                    "Wyckoff Label": wyckoff,
                }

                # No neighbor set means no environment was identified here.
                if not lse.neighbors_sets[index]:
                    unknown_sites.append(f"{struct[index].species_string} ({wyckoff})")
                    continue

                # represent the local environment as a molecule
                mol = Molecule.from_sites(
                    [struct[index]] + lse.neighbors_sets[index][0].neighb_sites
                )
                mol = mol.get_centered_molecule()
                mg = MoleculeGraph.with_empty_graph(molecule=mol)
                # Bond the central site (index 0) to each of its neighbors.
                for i in range(1, len(mol)):
                    mg.add_edge(0, i)

                view = html.Div(
                    [
                        StructureMoleculeComponent(
                            struct_or_mol=mg,
                            disable_callbacks=True,
                            id=f"{struct.composition.reduced_formula}_site_{index}",
                            scene_settings={"enableZoom": False, "defaultZoom": 0.6},
                        )._sub_layouts["struct"]
                    ],
                    style={"width": "300px", "height": "300px"},
                )

                env = lse.coordination_environments[index]
                co = all_ce.get_geometry_from_mp_symbol(env[0]["ce_symbol"])
                name = co.name
                if co.alternative_names:
                    name += f" (also known as {', '.join(co.alternative_names)})"

                datalist.update(
                    {
                        "Environment": name,
                        "IUPAC Symbol": co.IUPAC_symbol_str,
                        get_tooltip(
                            "CSM",
                            "The continuous symmetry measure (CSM) describes the similarity to an "
                            "ideal coordination environment. It can be understood as a 'distance' to "
                            "a shape and ranges from 0 to 100 in which 0 corresponds to a "
                            "coordination environment that is exactly identical to the ideal one. A "
                            "CSM larger than 5.0 already indicates a relatively strong distortion of "
                            "the investigated coordination environment.",
                        ): f"{env[0]['csm']:.2f}",
                        "Interactive View": view,
                    }
                )

                envs.append(get_data_list(datalist))

            # TODO: switch to tiles?
            # Two environments per row.
            envs_grouped = [envs[i : i + 2] for i in range(0, len(envs), 2)]
            analysis_contents = [
                Columns([Column(e, size=6) for e in env_group])
                for env_group in envs_grouped
            ]

            if unknown_sites:
                unknown_notice = html.Strong(
                    f"The following sites were not identified: {', '.join(unknown_sites)}. "
                    f"Please try changing the distance or angle cut-offs to identify these sites, "
                    f"or try an alternative algorithm such as LocalEnv."
                )
            else:
                unknown_notice = html.Span()

            return html.Div([html.Div(analysis_contents), html.Br(), unknown_notice])