import math
import numpy as np
import plotly.graph_objs as go
from dash import callback_context, dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from pymatgen.analysis.diffraction.tem import TEMCalculator
from pymatgen.analysis.diffraction.xrd import WAVELENGTHS, XRDCalculator
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from scipy.special import wofz
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.helpers.layouts import Box, Column, Columns, Loading
# Scherrer equation: Langford, J. I., and A. J. C. Wilson. "Scherrer after sixty years:
# a survey and some new results in the determination of crystallite size." Journal of
# Applied Crystallography 11.2 (1978): 102-113.
# https://doi.org/10.1107/S0021889878012844
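# For reference, the constructor signature of pymatgen's TEMCalculator (only voltage
# and beam_direction are currently exposed as inputs by TEMDiffractionComponent):
# def __init__(self, symprec: float = None, voltage: float = 200,
# beam_direction: Tuple[int, int, int] = (0, 0, 1), camera_length: int = 160,
# debye_waller_factors: Dict[str, float] = None, cs: float = 1) -> None: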
# Author: Matthew McDermott
# Contact: mcdermott@lbl.gov
class TEMDiffractionComponent(MPComponent):
    """Display a simulated TEM electron diffraction pattern for a structure.

    The pattern is computed with pymatgen's ``TEMCalculator`` for the structure
    held in this component's "structure" store, using the voltage and beam
    direction entered in the layout controls.
    """

    def __init__(self, *args, initial_structure=None, **kwargs):
        """
        :param initial_structure: structure placed in the "structure" store when
            the component is created (None until a structure is provided).
        """
        super().__init__(*args, **kwargs)
        self.create_store("structure", initial_data=initial_structure)

    def layout(self):
        """
        Build the component layout: the diffraction plot (size 8 column) next to
        the voltage and beam-direction inputs (size 4 column).

        :return: a Columns layout
        """
        voltage = self.get_numerical_input(
            kwarg_label="voltage",
            default=200,
            label="Voltage / kV",
            help_str="The incident wavelength with which to generate the diffraction pattern, "
            "typically corresponding to a TEM microscope’s voltage.",
        )
        beam_direction = self.get_numerical_input(
            kwarg_label="beam_direction",
            default=[0, 0, 1],
            label="Beam Direction",
            help_str="The direction of the electron beam fired onto the sample.",
            shape=(3,),
            is_int=True,
        )
        # TODO: add additional kwargs for TemCalculator, or switch to an alternative solution
        return Columns(
            [
                Column([Box(Loading(id=self.id("tem-plot")))], size=8),
                Column(
                    [voltage, html.Br(), beam_direction],
                    size=4,
                ),
            ],
        )

    def generate_callbacks(self, app, cache):
        """
        Register the callback that redraws the TEM pattern whenever the stored
        structure or any of the component's input kwargs change.

        :param app: Dash app to register the callback on
        :param cache: not used by this component
        """

        @app.callback(
            Output(self.id("tem-plot"), "children"),
            [
                Input(self.id("structure"), "data"),
                Input(self.get_all_kwargs_id(), "value"),
            ],
        )
        def generate_diffraction_pattern(structure, *args):
            # The store is empty when no initial structure was given; leave the
            # plot area untouched instead of failing on a missing structure.
            if structure is None:
                raise PreventUpdate
            structure = self.from_data(structure)
            # Values of the layout inputs (voltage, beam_direction), passed
            # straight through to TEMCalculator as keyword arguments.
            kwargs = self.reconstruct_kwargs_from_state()
            calculator = TEMCalculator(**kwargs)
            return dcc.Graph(
                figure=calculator.get_plot_2d(structure),
                responsive=False,
                config={"displayModeBar": False, "displaylogo": False},
            )
class XRayDiffractionComponent(MPComponent):
    """Display a simulated powder X-ray diffraction (XRD) pattern for a structure.

    The raw peak list is computed with pymatgen's ``XRDCalculator`` on the
    conventional standard cell and stored in this component's main store; the
    plotted figure is then built by ``get_figure`` from the selected peak
    profile, Scherrer shape factor, radiation source and crystallite size.
    """

    # TODO: add pole figures for a given single peak for help quantifying texture

    def __init__(self, *args, initial_structure=None, **kwargs):
        """
        :param initial_structure: structure placed in the "structure" store when
            the component is created (None until a structure is provided).
        """
        super().__init__(*args, **kwargs)
        self.create_store("structure", initial_data=initial_structure)

    # Default XRD plot style settings
    default_xrd_plot_style = dict(
        xaxis={
            "title": "2𝜃 / º",
            "anchor": "y",
            "nticks": 8,
            "showgrid": True,
            "showline": True,
            "side": "bottom",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"size": 16.0},
            "type": "linear",
            "zeroline": False,
        },
        yaxis={
            "title": "Intensity / arb. units",
            "anchor": "x",
            "nticks": 7,
            "showgrid": True,
            "showline": True,
            "side": "left",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"size": 16.0},
            "type": "linear",
            "zeroline": False,
        },
        autosize=True,
        hovermode="x",
        height=225,
        showlegend=False,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=60, b=50, t=50, pad=0, r=30),
        title="X-ray Diffraction Pattern",
        template="simple_white",
    )

    # Transparent layout with hidden axes, used for the figure shown before a
    # pattern has been computed.
    empty_plot_style = {
        "xaxis": {"visible": False},
        "yaxis": {"visible": False},
        "paper_bgcolor": "rgba(0,0,0,0)",
        "plot_bgcolor": "rgba(0,0,0,0)",
    }

    @staticmethod
    def G(x, c, alpha):
        """Return c-centered Gaussian line shape at x with HWHM alpha.

        The profile is normalized to unit area.
        """
        return (
            np.sqrt(np.log(2) / np.pi)
            / alpha
            * np.exp(-(((x - c) / alpha) ** 2) * np.log(2))
        )

    @staticmethod
    def L(x, c, gamma):
        """Return c-centered Lorentzian line shape at x with HWHM gamma.

        The profile is normalized to unit area.
        """
        return gamma / (np.pi * ((x - c) ** 2 + gamma**2))

    @staticmethod
    def V(x, c, alphagamma):
        """Return the c-centered Voigt line shape at x.

        The Gaussian and Lorentzian components use equal HWHMs, scaled so that
        the resulting Voigt HWHM approximately equals ``alphagamma``. This
        matches the widths of the Gaussian and Lorentzian profiles for the same
        argument. The profile is evaluated via the Faddeeva function ``wofz``
        and is normalized to unit area.
        """
        # 0.61065 ~= 2 / 3.275: with equal Gaussian and Lorentzian HWHM a, the
        # Olivero-Longbothum approximation gives a Voigt FWHM of ~3.275 * a, so
        # this factor makes the Voigt HWHM ~= alphagamma.
        alpha = 0.61065 * alphagamma
        gamma = 0.61065 * alphagamma
        # Convert the Gaussian HWHM to a standard deviation.
        sigma = alpha / np.sqrt(2 * np.log(2))
        return np.real(wofz(((x - c) + 1j * gamma) / (sigma * np.sqrt(2)))) / (
            sigma * np.sqrt(2 * np.pi)
        )

    @staticmethod
    def twotheta_to_q(twotheta, xray_wavelength):
        """
        Convert twotheta to Q, the magnitude of the scattering vector,
        Q = 4π sin(θ) / λ.

        :param twotheta: in degrees
        :param xray_wavelength: in Ångstroms
        :return: Q in inverse Ångstroms
        """
        # thanks @rwoodsrobinson
        return (4 * np.pi / xray_wavelength) * np.sin(np.deg2rad(twotheta) / 2)

    @staticmethod
    def grain_to_hwhm(tau, two_theta, K=0.9, wavelength="CuKa"):
        """
        Estimate a peak's half-width at half-maximum from the crystallite size
        using the Scherrer equation: HWHM = 0.5 * K * λ / (τ |cos(2θ / 2)|).

        :param tau: grain size in nm
        :param two_theta: diffraction angle 2θ in radians (it is halved and
            passed directly to ``np.cos``)
        :param K: shape factor (default 0.9)
        :param wavelength: either a key of pymatgen's ``WAVELENGTHS`` (e.g.
            "CuKa") or a wavelength in Ångstroms
        :return: half-width half-max (alpha or gamma) for the line profile, in
            radians
        """
        if isinstance(wavelength, str):
            wavelength = WAVELENGTHS[wavelength]
        # factor of 0.1 to convert the wavelength from Å to nm (the unit of tau)
        return (
            0.5 * K * 0.1 * wavelength / (tau * abs(np.cos(two_theta / 2)))
        )  # Scherrer equation for half-width half max

    @property
    def _sub_layouts(self):
        """
        Build the individual pieces of the XRD layout.

        :return: dict mapping "x_axis", "graph", "rad_source", "peak_profile",
            "shape_factor", "crystallite_size" and "static_image" to their
            Dash components
        """
        # Initial values of the input controls. "crystallite_size" is log10 of
        # the grain size in nm, because the slider below is log-scaled.
        state = {
            "peak_profile": "G",
            "shape_factor": 0.94,
            "rad_source": "CuKa",
            "x_axis": "twotheta",
            "crystallite_size": 0.1,
        }

        # Main plot
        graph = Loading(
            [
                dcc.Graph(
                    figure=go.Figure(layout=XRayDiffractionComponent.empty_plot_style),
                    id=self.id("xrd-plot"),
                    config={
                        "displayModeBar": False,  # or "hover",
                        "plotGlPixelRatio": 2,
                        "displaylogo": False,
                        # "modeBarButtons": [["toImage"]],  # to only add an image download button
                        "toImageButtonOptions": {
                            "format": "png",
                            "filename": "xrd",
                            "scale": 4,
                            "width": 600,
                            "height": 400,
                        },
                        "editable": True,
                    },
                    responsive=True,
                    animate=False,
                )
            ]
        )

        # Radiation source selector
        rad_source = self.get_choice_input(
            kwarg_label="rad_source",
            state=state,
            label="Radiation source",
            help_str="This defines the wavelength of the incident X-ray radiation.",
            options=[
                {
                    "label": f'{name.replace("a", "α").replace("b", "β")} ({wavelength:.3f} Å)',
                    "value": name,
                }
                for name, wavelength in WAVELENGTHS.items()
            ],
            style={"width": "10rem"},
        )

        # Shape factor input
        shape_factor = self.get_numerical_input(
            kwarg_label="shape_factor",
            state=state,
            label="Shape Factor",
            help_str=(
                "The shape factor K, also known as the “Scherrer constant”, is a "
                "dimensionless quantity used to obtain an actual particle size from an "
                "apparent particle size determined from XRD. The discrepancy arises "
                "because the shape of an individual crystallite changes the resulting "
                "diffraction broadening. Commonly, a value of 0.94 for isotropic "
                "crystals in a spherical shape is used. However, in practice K can vary "
                "from 0.62 to 2.08."
            ),
        )

        # Peak profile selector (Gaussian, Lorentzian, Voigt)
        peak_profile = self.get_choice_input(
            kwarg_label="peak_profile",
            state=state,
            label="Peak Profile",
            help_str=(
                "The peak profile determines what distribution characterizes the "
                "broadening of an XRD pattern. Two extremes are Gaussian distributions, "
                "which are useful for peaks with more rounded tops (typically due to "
                "strain broadening), and Lorentzian distributions, which are useful for "
                "peaks with sharper tops (typically due to size distributions and "
                "dislocations). In reality, peak shapes usually follow a Voigt "
                "distribution, which is a convolution of Gaussian and Lorentzian peak "
                "shapes, with the contributions of the Gaussian and Lorentzian "
                "components being sample- and instrument-dependent. Here, both "
                "contributions are equally weighted if Voigt is chosen."
            ),
            options=[
                {"label": "Gaussian", "value": "G"},
                {"label": "Lorentzian", "value": "L"},
                {"label": "Voigt", "value": "V"},
            ],
            style={"width": "10rem"},
        )

        # 2Theta or Q for x-axis
        x_axis_choice = html.Div(
            [
                self.get_choice_input(
                    kwarg_label="x_axis",
                    state=state,
                    label="Choice of 𝑥 axis",
                    help_str="Choose between 2𝜃 and Q, where Q is the magnitude of the "
                    "scattering vector (4π sin 𝜃 / λ) and is independent of the radiation source.",
                    options=[
                        {"label": "2𝜃", "value": "twotheta"},
                        {"label": "Q", "value": "Q"},
                    ],
                )
            ],
            style={
                "display": "none"
            },  # TODO: this is buggy! let's fix it before we share
        )

        # Crystallite size selector (via Scherrer Equation)
        crystallite_size = self.get_slider_input(
            kwarg_label="crystallite_size",
            label="Scherrer crystallite size / nm",
            state=state,
            help_str="Simulate a real diffraction pattern by applying Scherrer broadening, which estimates the "
            "full width at half maximum (FWHM) resulting from a finite, rather than infinite, crystallite "
            "size.",
            domain=[-1, 2],
            step=0.01,
            isLogScale=True,
        )

        static_image = self.get_figure_placeholder("xrd-plot")

        return {
            "x_axis": x_axis_choice,
            "graph": graph,
            "rad_source": rad_source,
            "peak_profile": peak_profile,
            "shape_factor": shape_factor,
            "crystallite_size": crystallite_size,
            "static_image": static_image,
        }

    def layout(self, static_image=False):
        """
        Get the standard XRD diffraction pattern layout.

        :param static_image: If True, will show a static image instead of an interactive graph.
        :return: a Columns layout with the plot (size 8) and the controls (size 4)
        """
        sub_layouts = self._sub_layouts
        if static_image:
            inner = sub_layouts["static_image"]
        else:
            inner = sub_layouts["graph"]

        return Columns(
            [
                Column(
                    [Box([inner], style={"height": "480px"})],
                    size=8,
                    style={"height": "600px"},
                ),
                Column(
                    [
                        sub_layouts["x_axis"],
                        sub_layouts["rad_source"],
                        sub_layouts["shape_factor"],
                        sub_layouts["peak_profile"],
                        sub_layouts["crystallite_size"],
                    ],
                    size=4,
                ),
            ]
        )

    def generate_callbacks(self, app, cache):
        """
        Register the XRD callbacks.

        - ``update_graph`` rebuilds the plotted figure from the stored peak
          list whenever the data or any broadening/display option changes.
        - ``pattern_from_struct`` computes the peak list for the stored
          structure and the selected radiation source into the main store.

        :param app: Dash app to register the callbacks on
        :param cache: not used by this component
        """

        @app.callback(
            Output(self.id("xrd-plot"), "figure"),
            [
                Input(self.id(), "data"),
                Input(self.get_kwarg_id("crystallite_size"), "value"),
                Input(self.get_kwarg_id("rad_source"), "value"),
                Input(self.get_kwarg_id("peak_profile"), "value"),
                Input(self.get_kwarg_id("shape_factor"), "value"),
                Input(self.get_kwarg_id("x_axis"), "value"),
            ],
        )
        def update_graph(data, logsize, rad_source, peak_profile, K, x_axis):
            if not data:
                raise PreventUpdate

            # Re-read every control value through the component's kwarg
            # reconstruction instead of using the raw positional Dash values.
            kwargs = self.reconstruct_kwargs_from_state(callback_context.inputs)
            if not kwargs:
                raise PreventUpdate

            peak_profile = kwargs["peak_profile"]
            K = kwargs["shape_factor"]
            rad_source = kwargs["rad_source"]
            # The crystallite-size slider is log-scaled: its value is log10(size / nm).
            logsize = float(kwargs["crystallite_size"])
            x_axis = kwargs["x_axis"]
            grain_size = 10**logsize

            # Fields of the serialized diffraction pattern written by
            # pattern_from_struct.
            x_peak = data["x"]
            y_peak = data["y"]
            d_hkls = data["d_hkls"]
            hkls = data["hkls"]

            # NOTE(review): get_figure is not defined in this class as shown;
            # confirm it is provided elsewhere (e.g. by MPComponent or a mixin).
            plot = self.get_figure(
                peak_profile,
                K,
                rad_source,
                grain_size,
                x_peak,
                y_peak,
                d_hkls,
                hkls,
                x_axis,
            )
            return plot

        @app.callback(
            Output(self.id(), "data"),
            [
                Input(self.id("structure"), "data"),
                Input(self.get_kwarg_id("rad_source"), "value"),
            ],
        )
        def pattern_from_struct(struct, rad_source):
            if struct is None or not rad_source:
                raise PreventUpdate

            struct = self.from_data(struct)
            rad_source = self.reconstruct_kwarg_from_state(
                callback_context.inputs, "rad_source"
            )

            sga = SpacegroupAnalyzer(struct)
            struct = (
                sga.get_conventional_standard_structure()
            )  # always get conventional structure

            xrdc = XRDCalculator(
                wavelength=WAVELENGTHS[rad_source], symprec=0, debye_waller_factors=None
            )
            data = xrdc.get_pattern(struct, two_theta_range=None)

            return data.as_dict()

        # Disabled: callback that would render the XRD figure to a base64 PNG
        # for the static-image placeholder.
        # @app.callback(
        #     Output(self.id("static-image"), "src"),
        #     [Input(self.id("xrd-plot"), "figure")]
        # )
        # def update_static_image(data):
        #
        #     scope = PlotlyScope()
        #     output = scope.transform(data, format="png", width=600, height=400, scale=4)
        #     image = b64encode(output).decode('ascii')
        #
        #     return f"data:image/png;base64,{image}"