import dash
import plotly.graph_objs as go
from dash import dash_table, dcc, html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from pymatgen.analysis.phase_diagram import PDEntry, PDPlotter, PhaseDiagram
from pymatgen.core.composition import Composition
from pymatgen.ext.matproj import MPRester
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.panelcomponent import PanelComponent
from crystal_toolkit.helpers.layouts import (
Button,
Column,
Columns,
MessageBody,
MessageContainer,
)
# Author: Matthew McDermott
# Contact: mcdermott@lbl.gov
class PhaseDiagramComponent(MPComponent):
    """Interactive phase diagram plot plus an editable table of its entries.

    Data flows between the stores wired up in ``generate_callbacks``:

    * ``mpid`` or ``chemsys-external`` -> ``chemsys-internal`` (a chemsys
      string such as "Co-O")
    * ``chemsys-internal`` -> entry table (entries fetched via ``MPRester``)
    * entry table rows -> ``entries`` store (a list of ``PDEntry``)
    * ``entries`` -> this component's main store (a ``PhaseDiagram``)
    * main store -> ``figure`` store -> graph rendered inside ``pd-div``

    Plotting is supported for 2-, 3- and 4-component systems only.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Auxiliary stores used to pass data between the callbacks defined in
        # generate_callbacks.
        self.create_store("mpid")
        self.create_store("chemsys-internal")
        self.create_store("chemsys-external")
        self.create_store("figure")
        self.create_store("entries")

    # Default plot layouts for Binary (2), Ternary (3), Quaternary (4) phase diagrams
    # NOTE(review): presumably consumed by ``figure_layout`` (called in
    # make_figure but not defined in this view) — confirm.
    # NOTE(review): ``titlefont`` is a legacy plotly attribute (newer plotly
    # releases use ``title.font``); confirm it is accepted by the pinned version.
    default_binary_plot_style = dict(
        xaxis={
            "title": "Fraction",
            "anchor": "y",
            "mirror": "ticks",
            "nticks": 8,
            "showgrid": False,
            "showline": True,
            "side": "bottom",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"color": "#000000", "size": 24.0},
            "type": "linear",
            "zeroline": False,
        },
        yaxis={
            "title": "Formation energy (eV/fu)",
            "anchor": "x",
            "mirror": "ticks",
            "nticks": 7,
            "showgrid": False,
            "showline": True,
            "side": "left",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"color": "#000000", "size": 24.0},
            "type": "linear",
            "zeroline": False,
        },
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        height=550,
        width=500,
        hovermode="closest",
        showlegend=True,
        legend=dict(
            orientation="h",
            traceorder="reversed",
            x=1.0,
            y=1.08,
            xanchor="right",
            tracegroupgap=5,
        ),
        margin=dict(l=80, b=70, t=10, r=20),
    )

    default_ternary_plot_style = dict(
        xaxis=dict(
            title=None,
            autorange=True,
            showgrid=False,
            zeroline=False,
            showline=False,
            ticks="",
            showticklabels=False,
        ),
        yaxis=dict(
            title=None,
            autorange=True,
            showgrid=False,
            zeroline=False,
            showline=False,
            ticks="",
            showticklabels=False,
        ),
        autosize=True,
        height=450,
        width=500,
        hovermode="closest",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(b=30, l=30, pad=0, t=0, r=20),
        showlegend=True,
        legend=dict(
            orientation="h",
            traceorder="reversed",
            x=1.0,
            y=1.08,
            xanchor="right",
            tracegroupgap=5,
        ),
    )

    # Hidden axis settings shared by all three axes of the 3D (quaternary) scene.
    default_3d_axis = dict(
        title=None,
        visible=False,
        autorange=True,
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks="",
        showaxeslabels=False,
        showticklabels=False,
        showspikes=False,
    )

    default_quaternary_plot_style = dict(
        autosize=True,
        height=450,
        hovermode="closest",
        margin=dict(b=30, l=30, pad=0, t=0, r=20),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        showlegend=True,
        legend=dict(
            orientation="h",
            traceorder="reversed",
            x=1.0,
            y=1.08,
            xanchor="right",
            tracegroupgap=5,
        ),
        scene=dict(xaxis=default_3d_axis, yaxis=default_3d_axis, zaxis=default_3d_axis),
    )

    # Layout for the placeholder graph shown before any diagram is computed.
    empty_plot_style = {
        "xaxis": {"visible": False},
        "yaxis": {"visible": False},
        "paper_bgcolor": "rgba(0,0,0,0)",
        "plot_bgcolor": "rgba(0,0,0,0)",
    }

    # Marker colorscale applied to the formation energies of stable phases.
    colorscale = [
        [0.0, "#008d00"],
        [0.1111111111111111, "#4b9f3f"],
        [0.2222222222222222, "#73b255"],
        [0.3333333333333333, "#97c65b"],
        [0.4444444444444444, "#b9db53"],
        [0.5555555555555556, "#ffdcdf"],
        [0.6666666666666666, "#ffb8bf"],
        [0.7777777777777778, "#fd92a0"],
        [0.8888888888888888, "#f46b86"],
        [1.0, "#e24377"],
    ]

    # Entry-table columns and whether the user may edit each one.
    default_table_params = [
        {"col": "Material ID", "edit": False},
        {"col": "Formula", "edit": True},
        {"col": "Formation Energy (eV/atom)", "edit": True},
        {"col": "Energy Above Hull (eV/atom)", "edit": False},
        {"col": "Predicted Stable?", "edit": False},
    ]

    # Template for a user-added row. Keys must match the column ids above
    # (the stability key previously lacked the trailing "?").
    empty_row = {
        "Material ID": None,
        "Formula": "INSERT",
        "Formation Energy (eV/atom)": "INSERT",
        "Energy Above Hull (eV/atom)": None,
        "Predicted Stable?": None,
    }

    def create_markers(self, plotter, pd):
        """Build the trace of stable-phase markers, colored by formation energy.

        Args:
            plotter: PDPlotter whose ``pd_plot_data[1]`` maps plot coordinates
                to stable entries.
            pd: PhaseDiagram the plotter was built from.

        Returns:
            go.Scatter for 2-/3-component diagrams, go.Scatter3d for
            4-component diagrams.

        Raises:
            ValueError: if ``pd.dim`` is not 2, 3 or 4 (previously this
                surfaced as an UnboundLocalError).
        """
        dim = pd.dim
        if dim not in (2, 3, 4):
            raise ValueError(f"Cannot plot a phase diagram with {dim} components.")

        x_list, y_list, z_list = [], [], []
        text = []
        energy_list = []
        for coord, entry in plotter.pd_plot_data[1].items():
            energy = round(pd.get_form_energy_per_atom(entry), 3)
            energy_list.append(energy)
            # For entries built by update_entries_store, ``attribute`` is the
            # Material ID or "Custom Entry".
            mpid = entry.attribute
            # NOTE(review): clean_formula is not defined in this view —
            # presumably provided elsewhere on the class; verify.
            clean_formula = self.clean_formula(entry.composition.reduced_formula)
            x_list.append(coord[0])
            y_list.append(coord[1])
            if dim == 4:
                z_list.append(coord[2])
            text.append(f"{clean_formula} ({mpid})<br> {energy} eV")

        trace_kwargs = dict(
            x=x_list,
            y=y_list,
            mode="markers",
            name="Stable",
            marker=dict(
                color=energy_list,
                size=8 if dim == 4 else 11,
                colorscale=self.colorscale,
                line=dict(width=2, color="#000000"),
            ),
            hoverinfo="text",
            hoverlabel=dict(font=dict(size=14)),
            hovertext=text,
            showlegend=True,
        )
        if dim == 4:
            return go.Scatter3d(z=z_list, **trace_kwargs)
        return go.Scatter(**trace_kwargs)

    def create_unstable_markers(self, plotter, pd):
        """Build the trace of unstable-phase markers (hidden until toggled).

        Args:
            plotter: PDPlotter whose ``pd_plot_data[2]`` maps unstable entries
                to plot coordinates.
            pd: PhaseDiagram the plotter was built from.

        Returns:
            go.Scatter for 2-/3-component diagrams, go.Scatter3d for
            4-component diagrams. The trace starts as ``visible="legendonly"``.

        Raises:
            ValueError: if ``pd.dim`` is not 2, 3 or 4.
        """
        dim = pd.dim
        if dim not in (2, 3, 4):
            raise ValueError(f"Cannot plot a phase diagram with {dim} components.")

        x_list, y_list, z_list = [], [], []
        text_list = []
        for unstable_entry, unstable_coord in plotter.pd_plot_data[2].items():
            x_list.append(unstable_coord[0])
            y_list.append(unstable_coord[1])
            if dim == 4:
                z_list.append(unstable_coord[2])
            mpid = unstable_entry.attribute
            # Pass the formula string, matching create_markers.
            clean_formula = self.clean_formula(
                unstable_entry.composition.reduced_formula
            )
            e_above_hull = round(pd.get_e_above_hull(unstable_entry), 3)
            energy = round(pd.get_form_energy_per_atom(unstable_entry), 3)
            text_list.append(
                f"{clean_formula} ({mpid})<br>{energy} eV (+{e_above_hull} eV)"
            )

        trace_kwargs = dict(
            x=x_list,
            y=y_list,
            mode="markers",
            hoverinfo="text",
            hovertext=text_list,
            visible="legendonly",
            name="Unstable",
        )
        if dim == 4:
            return go.Scatter3d(
                z=z_list,
                marker=dict(color="#ff0000", size=4, symbol="x"),
                **trace_kwargs,
            )
        return go.Scatter(
            marker=dict(color="#ff0000", size=12, symbol="x"), **trace_kwargs
        )

    @staticmethod
    def create_table_content(pd):
        """Build one entry-table row per entry in ``pd.all_entries``.

        Args:
            pd: PhaseDiagram whose entries populate the table.

        Returns:
            list of dicts keyed by the column names in ``default_table_params``.
            An entry whose energies cannot be computed yields an empty dict
            (best effort, so one bad entry does not break the whole table).
        """
        data = []
        for entry in pd.all_entries:
            try:
                mpid = entry.entry_id
            except AttributeError:
                mpid = entry.attribute  # accounting for custom entry
            try:
                form_energy = pd.get_form_energy_per_atom(entry)
                # Computed once and reused for both the value and the flag.
                e_above_hull = pd.get_e_above_hull(entry)
                data.append(
                    {
                        "Material ID": mpid,
                        "Formula": entry.name,
                        "Formation Energy (eV/atom)": round(form_energy, 3),
                        "Energy Above Hull (eV/atom)": round(e_above_hull, 3),
                        "Predicted Stable?": "Yes" if e_above_hull == 0 else "No",
                    }
                )
            except Exception:
                data.append({})
        return data

    @staticmethod
    def ternary_plot(plot_data):
        """Return a ternary phase diagram in a two-dimensional plot.

        Not implemented yet: this currently always returns an empty figure.

        Args:
            plot_data: plot data from PDPlotter (currently unused).

        Returns:
            go.Figure
        """
        # TODO: build go.Scatterternary traces (markers for phases, lines for
        # tie-lines) and a go.Layout with ``ternary=dict(sum=1, aaxis=...,
        # baxis=..., caxis=...)``. The previous placeholder passed ``...``
        # (Ellipsis) as trace data, which plotly's validators reject, so the
        # method raised instead of returning a figure.
        return go.Figure()

    @property
    def _sub_layouts(self):
        """Build fresh "graph" and "table" sub-layouts (new components per access)."""
        graph = html.Div(
            [
                dcc.Graph(
                    figure=go.Figure(layout=PhaseDiagramComponent.empty_plot_style),
                    id=self.id("graph"),
                    config={"displayModeBar": False, "displaylogo": False},
                )
            ],
            id=self.id("pd-div"),
        )

        table = html.Div(
            [
                html.Div(
                    dash_table.DataTable(
                        id=self.id("entry-table"),
                        columns=[
                            {
                                "id": p["col"],
                                "name": p["col"],
                                "editable": p["edit"],
                            }
                            for p in self.default_table_params
                        ],
                        style_table={
                            "maxHeight": "450px",
                            "overflowY": "auto",
                            "border": "thin lightgrey solid",
                        },
                        sort_action="native",
                        editable=True,
                        row_deletable=True,
                        style_header={
                            "backgroundColor": "rgb(230, 249, 255)",
                            "fontWeight": "bold",
                        },
                        style_cell={
                            "fontFamily": "IBM Plex Sans",
                            # "centered" is not a valid CSS text-align value.
                            "textAlign": "center",
                            "whiteSpace": "normal",
                        },
                        css=[
                            {
                                "selector": ".dash-cell div.dash-cell-value",
                                "rule": "display: inline; white-space: inherit; overflow: inherit; "
                                "text-overflow: inherit;",
                            }
                        ],
                        style_cell_conditional=[
                            {"if": {"column_id": "Material ID"}, "width": "20%"},
                            {"if": {"column_id": "Formula"}, "width": "20%"},
                        ],
                    )
                ),
                Button(
                    "Add Custom Entry",
                    id=self.id("editing-rows-button"),
                    kind="primary",
                    n_clicks=0,
                ),
                html.P("Enter composition and formation energy per atom."),
            ]
        )

        return {"graph": graph, "table": table}

    def layout(self):
        """Return the graph and entry table side by side."""
        # Evaluate the property once; each access builds new components.
        sub_layouts = self._sub_layouts
        return html.Div(
            [
                Columns(
                    [
                        Column(sub_layouts["graph"]),
                        Column(sub_layouts["table"]),
                    ],
                    centered=True,
                )
            ]
        )

    def generate_callbacks(self, app, cache):
        """Register the Dash callbacks that connect this component's stores."""

        @app.callback(
            Output(self.id("pd-div"), "children"), [Input(self.id("figure"), "data")]
        )
        def update_graph(figure):
            """Render the stored figure, or a warning if plotting was impossible."""
            if figure is None:
                raise PreventUpdate
            if figure == "error":
                return [
                    MessageContainer(
                        [
                            MessageBody(
                                dcc.Markdown(
                                    "Plotting is only available for phase diagrams containing 2-4 components."
                                )
                            )
                        ],
                        kind="warning",
                    )
                ]
            return [
                dcc.Graph(
                    figure=figure,
                    config={"displayModeBar": False, "displaylogo": False},
                )
            ]

        @app.callback(Output(self.id("figure"), "data"), [Input(self.id(), "data")])
        def make_figure(pd):
            """Build the figure for the stored PhaseDiagram, or "error" if dim is not 2-4."""
            if pd is None:
                raise PreventUpdate

            pd = self.from_data(pd)
            dim = pd.dim
            if dim not in [2, 3, 4]:
                return "error"

            plotter = PDPlotter(pd)
            line_style = {
                "color": "rgba (0, 0, 0, 1)",
                "dash": "solid",
                "width": 3.0,
            }

            # Hull lines first, then unstable markers, then stable markers on top.
            data = []
            for line in plotter.pd_plot_data[0]:
                if dim == 4:
                    data.append(
                        go.Scatter3d(
                            x=list(line[0]),
                            y=list(line[1]),
                            z=list(line[2]),
                            mode="lines",
                            hoverinfo="none",
                            line=line_style,
                            showlegend=False,
                        )
                    )
                else:
                    data.append(
                        go.Scatter(
                            x=list(line[0]),
                            y=list(line[1]),
                            mode="lines",
                            hoverinfo="none",
                            line=line_style,
                            showlegend=False,
                        )
                    )

            data.append(self.create_unstable_markers(plotter, pd))
            data.append(self.create_markers(plotter, pd))

            fig = go.Figure(data=data)
            # NOTE(review): figure_layout is not defined in this view —
            # presumably provided elsewhere on the class; verify.
            fig.layout = self.figure_layout(plotter, pd)

            return fig

        @app.callback(Output(self.id(), "data"), [Input(self.id("entries"), "data")])
        def create_pd_object(entries):
            """Build a PhaseDiagram from the entries store."""
            if not entries:
                raise PreventUpdate

            entries = self.from_data(entries)

            return PhaseDiagram(entries)

        @app.callback(
            Output(self.id("entries"), "data"),
            [Input(self.id("entry-table"), "derived_virtual_data")],
        )
        def update_entries_store(rows):
            """Convert table rows into PDEntry objects, skipping rows that do not parse."""
            if rows is None:
                raise PreventUpdate
            entries = []
            for row in rows:
                try:
                    comp = Composition(row["Formula"])
                    energy = row["Formation Energy (eV/atom)"]
                    material_id = row["Material ID"]
                    attribute = "Custom Entry" if material_id is None else material_id
                    # Store the mpid as the entry attribute so MP and custom
                    # entries can be combined; energy is per atom in the table,
                    # so scale it to the whole composition.
                    entry = PDEntry(
                        comp, float(energy) * comp.num_atoms, attribute=attribute
                    )
                    entries.append(entry)
                except Exception:
                    # Best effort: placeholder ("INSERT") or empty rows are skipped.
                    continue

            if not entries:
                raise PreventUpdate

            return entries

        @app.callback(
            Output(self.id("entry-table"), "data"),
            [
                Input(self.id("chemsys-internal"), "data"),
                Input(self.id(), "modified_timestamp"),
                Input(self.id("editing-rows-button"), "n_clicks"),
            ],
            [State(self.id(), "data"), State(self.id("entry-table"), "data")],
        )
        def create_table(chemsys, pd_time, n_clicks, pd, rows):
            """Refresh the entry table for PD updates, row additions, or a new chemsys."""
            ctx = dash.callback_context

            if ctx is None or not ctx.triggered or chemsys is None:
                raise PreventUpdate

            trigger = ctx.triggered[0]

            # PD update trigger
            if trigger["prop_id"] == f"{self.id()}.modified_timestamp":
                return self.create_table_content(self.from_data(pd))

            if trigger["prop_id"] == f"{self.id('editing-rows-button')}.n_clicks" and n_clicks:
                # Copy the class-level template so each added row is an
                # independent dict. An empty table gets a new row instead of
                # falling through to a fresh Materials Project query.
                return [*(rows or []), dict(self.empty_row)]

            with MPRester() as mpr:
                entries = mpr.get_entries_in_chemsys(chemsys)

            return self.create_table_content(PhaseDiagram(entries))

        @app.callback(
            Output(self.id("chemsys-internal"), "data"),
            [
                Input(self.id("mpid"), "data"),
                Input(self.id("chemsys-external"), "data"),
            ],
        )
        def get_chemsys_from_mpid_or_chemsys(mpid, chemsys_external: str):
            """Resolve the chemical system from whichever input fired.

            Args:
                mpid: Materials Project ID; its chemsys is looked up via MPRester.
                chemsys_external: chemsys, e.g. "Co-O", used as-is.

            Returns:
                chemsys string, or None if the trigger matched neither input.
            """
            ctx = dash.callback_context

            if ctx is None or not ctx.triggered:
                raise PreventUpdate

            trigger = ctx.triggered[0]

            if trigger["value"] is None:
                raise PreventUpdate

            chemsys = None

            if trigger["prop_id"] == f"{self.id('mpid')}.data":
                # get entries by mpid
                with MPRester() as mpr:
                    entry = mpr.get_entry_by_material_id(mpid)
                chemsys = entry.composition.chemical_system
            elif trigger["prop_id"] == f"{self.id('chemsys-external')}.data":
                # get entries by chemsys
                chemsys = chemsys_external

            return chemsys
class PhaseDiagramPanelComponent(PanelComponent):
    """Panel that embeds a PhaseDiagramComponent."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pd_component = PhaseDiagramComponent()
        # NOTE(review): presumably links this panel's "struct" store to the
        # inner component — confirm against MPComponent.attach_from.
        self.pd_component.attach_from(self, this_store_name="struct")

    @property
    def title(self):
        """Panel title shown in the UI."""
        return "Phase Diagram"

    @property
    def description(self):
        """Short panel description shown in the UI."""
        return (
            "Display the compositional phase diagram for the"
            " chemical system containing this structure (between 2-4 species)."
        )

    def update_contents(self, new_store_contents, *args):
        """Return the inner phase diagram component's layout.

        ``PhaseDiagramComponent.layout`` is a plain method, so it must be
        called; returning the bound method would hand Dash an object it
        cannot serialize.
        """
        return self.pd_component.layout()