import itertools
import numpy as np
import plotly.graph_objs as go
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from dash_mp_components import CrystalToolkitScene
from pymatgen.core.periodic_table import Element
from pymatgen.electronic_structure.bandstructure import (
BandStructure,
BandStructureSymmLine,
)
from pymatgen.electronic_structure.core import Spin
from pymatgen.electronic_structure.dos import CompleteDos
from pymatgen.electronic_structure.plotter import BSPlotter
from pymatgen.ext.matproj import MPRester
from pymatgen.symmetry.bandstructure import HighSymmKpath
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.panelcomponent import PanelComponent
from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres
from crystal_toolkit.helpers.layouts import (
Column,
Columns,
Label,
Loading,
MessageBody,
MessageContainer,
dcc,
get_data_list,
html,
)
# Author: Jason Munro
# Contact: jmunro@lbl.gov
# TODO: think about moving functionality to BSPlotter, DosPlotter
# TODO: remove access to private attributes of BSPlotter
class BandstructureAndDosComponent(MPComponent):
    """Compound component combining a band structure and a density of states.

    The component keeps three values in its default data store: an ``mpid``,
    a ``bandstructure_symm_line`` and a ``density_of_states``.
    """

    def __init__(
        self,
        mpid=None,
        bandstructure_symm_line=None,
        density_of_states=None,
        id=None,
        **kwargs,
    ):
        """Record the supplied inputs as the component's default data.

        Args:
            mpid: identifier used to look the data up; stored under "mpid".
            bandstructure_symm_line: band structure, stored under
                "bandstructure_symm_line".
            density_of_states: density of states, stored under
                "density_of_states".
            id: component id forwarded to the parent constructor.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        # This is a compound component: it can be fed by mpid or by the
        # band structure / density of states themselves, so all three
        # values are stored side by side.
        initial_store = {
            "mpid": mpid,
            "bandstructure_symm_line": bandstructure_symm_line,
            "density_of_states": density_of_states,
        }
        super().__init__(id=id, default_data=initial_store, **kwargs)
@property
def _sub_layouts(self):
    """Assemble the individual pieces that make up the component layout.

    Returns:
        dict: the layout pieces keyed by "graph", "convention",
        "dos-select", "label-select", "zone" and "table".
    """

    def convention_options():
        # A fresh list per call so the two dropdowns never share one object.
        return [
            {"label": "Latimer-Munro", "value": "lm"},
            {"label": "Hinuma et al.", "value": "hin"},
            {"label": "Setyawan-Curtarolo", "value": "sc"},
        ]

    initial = self.initial_data["default"]

    # Default selections handed to the choice inputs.
    state = {"label-select": "lm", "dos-select": "ap"}

    bs, dos = BandstructureAndDosComponent._get_bs_dos(initial)
    fig = BandstructureAndDosComponent.get_figure(bs, dos)

    # Main plot
    plot = dcc.Graph(
        figure=fig,
        config={"displayModeBar": False},
        responsive=True,
    )
    graph = Loading([plot], id=self.id("bsdos-div"))

    # Brillouin zone
    zone_scene = self.get_brillouin_zone_scene(bs)
    zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px")

    # Path options are hidden unless the component was loaded by mpid:
    # switching between k-paths on-the-fly is only supported for band
    # structures retrieved from MP.
    show_path_options = bool(initial["mpid"])
    if show_path_options:
        convention_style = {"width": "200px"}
        label_style = {"width": "200px"}
    else:
        convention_style = {"max-width": "200", "display": "none"}
        label_style = {"width": "200px", "display": "none"}

    # Convention selection for band structure
    convention_input = self.get_choice_input(
        kwarg_label="path-convention",
        state=state,
        label="Path convention",
        help_str="Convention to choose path in k-space",
        options=convention_options(),
    )
    convention = html.Div(
        [convention_input],
        style=convention_style,
        id=self.id("path-container"),
    )

    # Equivalent labels across band structure conventions
    label_input = self.get_choice_input(
        kwarg_label="label-select",
        state=state,
        label="Label convention",
        help_str="Convention to choose labels for path in k-space",
        options=convention_options(),
    )
    label_select = html.Div(
        [label_input],
        style=label_style,
        id=self.id("label-container"),
    )

    # Density of states data selection
    dos_select = self.get_choice_input(
        kwarg_label="dos-select",
        state=state,
        label="Projection",
        help_str="Choose projection",
        options=[{"label": "Atom Projected", "value": "ap"}],
        style={"width": "200px"},
    )

    table = get_data_list(self._get_data_list_dict(bs, dos))

    return {
        "graph": graph,
        "convention": convention,
        "dos-select": dos_select,
        "label-select": label_select,
        "zone": zone,
        "table": table,
    }
def layout(self):
    """Arrange the sub-layouts into three stacked rows.

    Returns:
        An ``html.Div`` holding, top to bottom: the graph, the three
        selection controls, and the summary table beside the Brillouin zone.
    """
    parts = self._sub_layouts

    plot_row = Columns([Column([parts["graph"]])])

    controls = Column(
        [
            parts["convention"],
            parts["label-select"],
            parts["dos-select"],
        ]
    )
    controls_row = Columns([controls])

    summary_column = Column([Label("Summary"), parts["table"]])
    zone_column = Column([Label("Brillouin Zone"), parts["zone"]])
    details_row = Columns([summary_column, zone_column])

    return html.Div([plot_row, controls_row, details_row])
@staticmethod
def _get_bs_dos(data):
data = data or {}
# this component can be loaded either from mpid or
# directly from BandStructureSymmLine or CompleteDos objects
# if mpid is supplied, this is preferred
mpid = data.get("mpid")
bandstructure_symm_line = data.get("bandstructure_symm_line")
density_of_states = data.get("density_of_states")
if not mpid and bandstructure_symm_line is None and density_of_states is None:
return None, None
if mpid:
with MPRester() as mpr:
try:
bandstructure_symm_line = mpr.get_bandstructure_by_material_id(mpid)
except Exception as exc:
print(exc)
bandstructure_symm_line = None
try:
density_of_states = mpr.get_dos_by_material_id(mpid)
except Exception as exc:
print(exc)
density_of_states = None
else:
if bandstructure_symm_line and isinstance(bandstructure_symm_line, dict):
bandstructure_symm_line = BandStructureSymmLine.from_dict(
bandstructure_symm_line
)
if density_of_states and isinstance(density_of_states, dict):
density_of_states = CompleteDos.from_dict(density_of_states)
return bandstructure_symm_line, density_of_states
@staticmethod
def get_ifermi_scene(bs: BandStructure) -> Scene:
    """Placeholder: ignores ``bs`` and currently returns None."""
    return None
@staticmethod
def get_brillouin_zone_scene(bs: BandStructureSymmLine) -> Scene:
    """Build a scene of the Brillouin zone and the k-path of ``bs``.

    The scene contains the Wigner-Seitz cell of the reciprocal lattice
    (edges and a translucent surface), the path branches as cylinders, a
    translucent hull over the path points, one sphere per labelled k-point
    and, when both are available, spheres marking the CBM and VBM.

    Args:
        bs: band structure along symmetry lines; a falsy value yields an
            empty scene.

    Returns:
        Scene: a scene named "brillouin_zone".
    """
    if not bs:
        return Scene(name="brillouin_zone", contents=[])

    # - Strip latex math wrapping for labels
    # TODO: add to string utils in pymatgen
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "GAMMA": "Γ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
        "_{1}": "₁",
        "_{2}": "₂",
        "_{3}": "₃",
        "_{4}": "₄",
        "^{*}": "*",
    }

    def clean_label(raw_label):
        # Apply every replacement in table order.
        for orig, new in str_replace.items():
            raw_label = raw_label.replace(orig, new)
        return raw_label

    # TODO: from BSPlotter, merge back into BSPlotter
    # Brillouin zone
    bz_lattice = bs.structure.lattice.reciprocal_lattice
    bz = bz_lattice.get_wigner_seitz_cell()

    # An edge is a pair of vertices of one face that both also belong to
    # a later face; each match contributes its two end points.
    lines = []
    for iface, face in enumerate(bz):
        for line in itertools.combinations(face, 2):
            for other_face in bz[iface + 1 :]:
                if any(np.all(line[0] == x) for x in other_face) and any(
                    np.all(line[1] == x) for x in other_face
                ):
                    lines += [list(line[0]), list(line[1])]

    zone_lines = Lines(positions=lines)
    zone_surface = Convex(positions=lines, opacity=0.05, color="#000000")

    # One sphere per distinct cleaned label; a later k-point with the same
    # label overwrites the coordinates of an earlier one.
    label_coords = {}
    for kpoint in bs.kpoints:
        if kpoint.label:
            label_coords[clean_label(kpoint.label)] = bz_lattice.get_cartesian_coords(
                kpoint.frac_coords
            )

    label_spheres = [
        Spheres(positions=[coords], tooltip=label, radius=0.03, color="#5EB1BF")
        for label, coords in label_coords.items()
    ]

    path = []
    cylinder_pairs = []
    for branch in bs.branches:
        start = bz_lattice.get_cartesian_coords(
            bs.kpoints[branch["start_index"]].frac_coords
        )
        end = bz_lattice.get_cartesian_coords(
            bs.kpoints[branch["end_index"]].frac_coords
        )
        path += [start, end]
        cylinder_pairs.append([start, end])

    path_lines = Cylinders(
        positionPairs=cylinder_pairs, color="#5EB1BF", radius=0.01
    )
    ibz_region = Convex(positions=path, opacity=0.2, color="#5EB1BF")

    contents = [zone_lines, zone_surface, path_lines, ibz_region, *label_spheres]

    cbm_kpoint = bs.get_cbm()["kpoint"]
    vbm_kpoint = bs.get_vbm()["kpoint"]

    if cbm_kpoint and vbm_kpoint:
        # Decide this once, on the k-points themselves, before any sphere
        # is built, so a shared k-point gets exactly one combined marker.
        same_kpoint = cbm_kpoint == vbm_kpoint

        if cbm_kpoint.label:
            cbm_label = f"CBM at {clean_label(cbm_kpoint.label)}"
        else:
            cbm_label = "CBM"
        if same_kpoint:
            cbm_label = f"VBM and {cbm_label}"

        contents.append(
            Spheres(
                positions=[bz_lattice.get_cartesian_coords(cbm_kpoint.frac_coords)],
                tooltip=cbm_label,
                radius=0.05,
                color="#7E259B",
            )
        )

        if not same_kpoint:
            if vbm_kpoint.label:
                vbm_label = f"VBM at {clean_label(vbm_kpoint.label)}"
            else:
                vbm_label = "VBM"

            contents.append(
                Spheres(
                    positions=[
                        bz_lattice.get_cartesian_coords(vbm_kpoint.frac_coords)
                    ],
                    tooltip=vbm_label,
                    radius=0.05,
                    color="#7E259B",
                )
            )

    return Scene(name="brillouin_zone", contents=contents)
@staticmethod
def get_bandstructure_traces(bs, path_convention, energy_window=(-6.0, 10.0)):
    """Create plot traces for a band structure.

    Args:
        bs: band structure to plot. When ``path_convention`` is "lm" it is
            first passed through ``HighSymmKpath.get_continuous_path``.
        path_convention: k-path convention key; only "lm" is treated
            specially here.
        energy_window: ``(low, high)`` pair. Only bands whose spin-up
            energies have some point at or below ``high`` and some point
            at or above ``low`` in at least one segment are drawn.

    Returns:
        tuple: ``(bs_traces, bs_data)``, where ``bs_traces`` is a list of
        trace dicts (bands, segment separators, CBM/VBM markers) and
        ``bs_data`` is the ``BSPlotter.bs_plot_data`` result with its tick
        labels cleaned in place.
    """
    if path_convention == "lm":
        bs = HighSymmKpath.get_continuous_path(bs)

    bs_data = BSPlotter(bs).bs_plot_data(split_branches=False)
    up_segments = bs_data["energy"][str(Spin.up)]

    # Each band index appears at most once, even when the band falls inside
    # the window in several segments; otherwise the same band would be
    # drawn repeatedly in every segment.
    bands = [
        band_num
        for band_num in range(bs.nb_bands)
        if any(
            any(segment[band_num] <= energy_window[1])
            and any(segment[band_num] >= energy_window[0])
            for segment in up_segments
        )
    ]

    bs_traces = []
    bar_loc = []

    for seg_idx, x_dat in enumerate(bs_data["distances"]):
        segment = up_segments[seg_idx]

        bs_traces += [
            {
                "x": x_dat,
                "y": segment[band_num],
                "mode": "lines",
                "line": {"color": "#1f77b4"},
                "hoverinfo": "skip",
                "name": "spin ↑" if bs.is_spin_polarized else "Total",
                "hovertemplate": "%{y:.2f} eV",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            }
            for band_num in bands
        ]

        if bs.is_spin_polarized:
            down_segment = bs_data["energy"][str(Spin.down)][seg_idx]
            bs_traces += [
                {
                    "x": x_dat,
                    "y": [down_segment[band_num][j] for j in range(len(x_dat))],
                    "mode": "lines",
                    "line": {"color": "#ff7f0e", "dash": "dot"},
                    "hoverinfo": "skip",
                    "showlegend": False,
                    "name": "spin ↓",
                    "hovertemplate": "%{y:.2f} eV",
                    "xaxis": "x",
                    "yaxis": "y",
                }
                for band_num in bands
            ]

        # The end of every segment gets a vertical separator below.
        bar_loc.append(x_dat[-1])

    # - Strip latex math wrapping for labels
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "GAMMA": "Γ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
        "_{1}": "₁",
        "_{2}": "₂",
        "_{3}": "₃",
        "_{4}": "₄",
        "^{*}": "*",
    }

    # The tick labels are rewritten in place inside bs_data.
    tick_labels = bs_data["ticks"]["label"]
    for idx, tick in enumerate(tick_labels):
        for orig, new in str_replace.items():
            tick = tick.replace(orig, new)
        tick_labels[idx] = tick

    # Vertical lines for disjointed segments
    bs_traces += [
        {
            "x": [x_point, x_point],
            "y": energy_window,
            "mode": "lines",
            "marker": {"color": "white"},
            "hoverinfo": "skip",
            "showlegend": False,
            "xaxis": "x",
            "yaxis": "y",
        }
        for x_point in bar_loc
    ]

    def extremum_markers(points, prefix, extremum):
        # The hover text reads the extremum's k-point, so it is only built
        # when there is at least one point to mark.
        unique_points = set(points)
        if not unique_points:
            return []
        hover = (
            f"{prefix}: k = {list(extremum['kpoint'].frac_coords)}, "
            f"{extremum['energy']} eV"
        )
        return [
            {
                "x": [x_point],
                "y": [y_point],
                "mode": "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {"color": "white", "width": 2},
                },
                "showlegend": False,
                "hoverinfo": "text",
                "name": "",
                "hovertemplate": hover,
                "xaxis": "x",
                "yaxis": "y",
            }
            for (x_point, y_point) in unique_points
        ]

    # Dots for cbm and vbm
    bs_traces += extremum_markers(bs_data["cbm"], "CBM", bs.get_cbm())
    bs_traces += extremum_markers(bs_data["vbm"], "VBM", bs.get_vbm())

    return bs_traces, bs_data
@staticmethod
def get_dos_traces(dos, dos_select, energy_window=(-6.0, 10.0)):
    """Create plot traces for a total and projected density of states.

    Energies are plotted relative to ``dos.efermi`` on the second axes
    ("x2"/"y2"). Spin-down densities, when present, are drawn negated.

    Args:
        dos: density of states providing ``energies``, ``efermi``,
            ``densities`` and the projection getters used below.
        dos_select: projection key. "tot" draws no projections, "ap" uses
            ``get_element_dos``, "op" uses ``get_spd_dos``, and a value
            containing "orb" uses ``get_element_spd_dos`` for the element
            named by the rest of the string.
        energy_window: ``(low, high)`` pair relative to the Fermi level;
            the plotted slice runs between the grid points closest to it.

    Returns:
        list: trace dicts, total DOS first, then the projections.

    Raises:
        PreventUpdate: if ``dos_select`` is none of the values above.
    """
    dostraces = []

    dos_max = np.abs(dos.energies - dos.efermi - energy_window[1]).argmin()
    dos_min = np.abs(dos.energies - dos.efermi - energy_window[0]).argmin()
    energies = dos.energies[dos_min:dos_max] - dos.efermi

    # TODO: pymatgen should have a property here
    spin_polarized = len(dos.densities) == 2

    if spin_polarized:
        # Add second spin data if available
        dostraces.append(
            {
                "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                "y": energies,
                "mode": "lines",
                "name": "Total DOS (spin ↓)",
                "line": go.scatter.Line(color="#444444", dash="dot"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "xaxis": "x2",
                "yaxis": "y2",
            }
        )
        tdos_label = "Total DOS (spin ↑)"
    else:
        tdos_label = "Total DOS"

    # Total DOS
    dostraces.append(
        {
            "x": dos.densities[Spin.up][dos_min:dos_max],
            "y": energies,
            "mode": "lines",
            "name": tdos_label,
            "line": go.scatter.Line(color="#444444"),
            "fill": "tozerox",
            "fillcolor": "#C4C4C4",
            "legendgroup": "spinup",
            "xaxis": "x2",
            "yaxis": "y2",
        }
    )

    if dos_select == "tot":
        proj_data = {}
    elif dos_select == "ap":
        proj_data = dos.get_element_dos()
    elif dos_select == "op":
        proj_data = dos.get_spd_dos()
    elif "orb" in dos_select:
        proj_data = dos.get_element_spd_dos(Element(dos_select.replace("orb", "")))
    else:
        raise PreventUpdate

    # Projected DOS
    colors = [
        "#d62728",  # brick red
        "#2ca02c",  # cooked asparagus green
        "#17becf",  # blue-teal
        "#bcbd22",  # curry yellow-green
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
    ]

    for count, (label, projection) in enumerate(proj_data.items()):
        # Cycle through the palette so more projections than colors do not
        # run off the end of the list.
        color = colors[count % len(colors)]

        if spin_polarized:
            dostraces.append(
                {
                    "x": -1.0 * projection.densities[Spin.down][dos_min:dos_max],
                    "y": energies,
                    "mode": "lines",
                    "name": f"{label} (spin ↓)",
                    "line": dict(width=3, color=color, dash="dot"),
                    "xaxis": "x2",
                    "yaxis": "y2",
                }
            )
            spin_up_label = f"{label} (spin ↑)"
        else:
            spin_up_label = str(label)

        dostraces.append(
            {
                "x": projection.densities[Spin.up][dos_min:dos_max],
                "y": energies,
                "mode": "lines",
                "name": spin_up_label,
                "line": dict(width=2, color=color),
                "xaxis": "x2",
                "yaxis": "y2",
            }
        )

    return dostraces
@staticmethod
def _get_data_list_dict(bs, dos):
return {
"Band Gap": "... eV",
"Direct Gap": "...",
"CBM": "...",
"VBM": "...",
"Spin Polarization": "...",
}
def generate_callbacks(self, app, cache):
    """Register the Dash callbacks of this component on ``app``.

    Four callbacks are registered:

    * ``update_graph`` redraws the plot whenever the "traces" store changes.
    * ``update_label_select`` mirrors the chosen path convention into the
      label selector when an mpid is set.
    * ``update_select`` fills the projection and path-convention options.
    * ``bs_dos_data`` computes the traces and the list of element symbols.

    Args:
        app: Dash application to register the callbacks on.
        cache: accepted for interface compatibility; not used here.
    """
    # Imported locally because the module header only imports Input/Output.
    from dash.dependencies import State

    def projection_options(elements):
        # Shared by both branches of update_select.
        return [
            {"label": "Element Projected", "value": "ap"},
            {"label": "Orbital Projected - Total", "value": "op"},
            *(
                {
                    "label": f"Orbital Projected - {ele_label}",
                    "value": f"orb{ele_label}",
                }
                for ele_label in elements
            ),
        ]

    @app.callback(
        Output(self.id("bsdos-div"), "children"),
        [Input(self.id("traces"), "data")],
        [
            State(self.id(), "data"),
            State(self.id("path-convention"), "value"),
            State(self.id("dos-select"), "value"),
        ],
    )
    def update_graph(traces, data, path_convention, dos_select):
        if traces == "error":
            body = MessageBody(
                dcc.Markdown(
                    "Band structure and density of states not available for this selection."
                )
            )
            return MessageContainer([body], kind="warning")

        if traces is None:
            raise PreventUpdate

        # The figure is rebuilt from the component data and the current
        # selections, read as State so only a change of the traces store
        # triggers a redraw.
        # NOTE(review): this resolves the data a second time (a remote
        # lookup when an mpid is set) -- consider caching.
        bs, dos = self._get_bs_dos(data)
        figure = self.get_figure(bs, dos, path_convention, dos_select)

        return [
            dcc.Graph(
                figure=figure, config={"displayModeBar": False}, responsive=True
            )
        ]

    @app.callback(
        [
            Output(self.id("label-select"), "value"),
            Output(self.id("label-container"), "style"),
        ],
        [
            Input(self.id("mpid"), "data"),
            Input(self.id("path-convention"), "value"),
        ],
    )
    def update_label_select(mpid, path_convention):
        if not mpid:
            raise PreventUpdate
        return [path_convention, {"max-width": "200"}]

    @app.callback(
        [
            Output(self.id("dos-select"), "options"),
            Output(self.id("path-convention"), "options"),
            Output(self.id("path-container"), "style"),
        ],
        [Input(self.id("elements"), "data"), Input(self.id("mpid"), "data")],
    )
    def update_select(elements, mpid):
        if elements is None:
            raise PreventUpdate

        dos_options = projection_options(elements)

        if not mpid:
            # Without an mpid there is no choice of path: offer a single
            # placeholder and keep the selector hidden.
            path_options = [{"label": "N/A", "value": "sc"}]
            path_style = {"max-width": "200", "display": "none"}
        else:
            path_options = [
                {"label": "Setyawan-Curtarolo", "value": "sc"},
                {"label": "Latimer-Munro", "value": "lm"},
                {"label": "Hinuma et al.", "value": "hin"},
            ]
            path_style = {"max-width": "200"}

        return [dos_options, path_options, path_style]

    @app.callback(
        [Output(self.id("traces"), "data"), Output(self.id("elements"), "data")],
        [
            Input(self.id(), "data"),
            Input(self.id("path-convention"), "value"),
            Input(self.id("dos-select"), "value"),
            Input(self.id("label-select"), "value"),
        ],
    )
    def bs_dos_data(data, path_convention, dos_select, label_select):
        # label_select is listed as an Input so a label change re-runs this
        # callback; its value is not used in the computation.
        energy_window = (-6.0, 10.0)

        bs, dos = self._get_bs_dos(data)

        if bs is None and dos is None:
            if not (data or {}).get("mpid"):
                # Nothing has been requested yet.
                raise PreventUpdate
            # A lookup was requested but produced nothing; update_graph
            # shows its warning message for this sentinel.
            return "error", None

        traces = []

        if bs:
            # get_bandstructure_traces returns (traces, plot data); the
            # pair is stored as one entry.
            traces.append(
                self.get_bandstructure_traces(
                    bs, path_convention, energy_window=energy_window
                )
            )

        if dos:
            traces.append(
                self.get_dos_traces(dos, dos_select, energy_window=energy_window)
            )

        # Element symbols feed the "Orbital Projected - <element>" options.
        # NOTE(review): assumes both objects expose a pymatgen ``structure``
        # (the band structure's is used elsewhere in this class) -- confirm
        # for the density of states.
        structure = dos.structure if dos else bs.structure
        elements = list(structure.symbol_set) if structure is not None else None

        return (traces, elements)
class BandstructureAndDosPanelComponent(PanelComponent):
    """Panel that wraps a ``BandstructureAndDosComponent``."""

    def __init__(self, *args, **kwargs):
        """Create the panel and attach an inner band structure component.

        All arguments are forwarded to the parent constructor. The inner
        component is attached from this panel with its store named "mpid".
        """
        super().__init__(*args, **kwargs)
        self.bs = BandstructureAndDosComponent()
        self.bs.attach_from(self, this_store_name="mpid")

    @property
    def title(self):
        """Heading of the panel."""
        return "Band Structure and Density of States"

    @property
    def description(self):
        """One-sentence description of what the panel shows."""
        return (
            "Display the band structure and density of states for this structure "
            "if it has been calculated by the Materials Project."
        )

    @property
    def initial_contents(self):
        """Parent contents followed by the inner layout, kept hidden."""
        parent_contents = super().initial_contents
        hidden_layout = html.Div(
            [self.bs.standard_layout], style={"display": "none"}
        )
        return html.Div([parent_contents, hidden_layout])

    def update_contents(self, new_store_contents, *args):
        """Return the inner component's standard layout.

        ``new_store_contents`` and any further arguments are not read.
        """
        return self.bs.standard_layout