import math
import numpy as np
import plotly.graph_objs as go
from dash import callback_context, dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from pymatgen.analysis.diffraction.tem import TEMCalculator
from pymatgen.analysis.diffraction.xrd import WAVELENGTHS, XRDCalculator
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from scipy.special import wofz
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.helpers.layouts import Box, Column, Columns, Loading
# Scherrer equation: Langford, J. I., and A. J. C. Wilson. "Scherrer after sixty years:
# a survey and some new results in the determination of crystallite size." Journal of
# Applied Crystallography 11.2 (1978): 102-113.
# https://doi.org/10.1107/S0021889878012844
#    def __init__(self, symprec: float = None, voltage: float = 200,
#                beam_direction: Tuple[int, int, int] = (0, 0, 1), camera_length: int = 160,
#                debye_waller_factors: Dict[str, float] = None, cs: float = 1) -> None:
# Author: Matthew McDermott
# Contact: mcdermott@lbl.gov
class TEMDiffractionComponent(MPComponent):
    """
    Dash component that displays a simulated TEM diffraction pattern.

    The structure to plot is kept in this component's "structure" store and
    the keyword arguments handed to ``TEMCalculator`` are collected from the
    numerical inputs created in :meth:`layout`.
    """

    def __init__(self, *args, initial_structure=None, **kwargs):
        """
        :param args: forwarded unchanged to the parent constructor
        :param initial_structure: initial data for the "structure" store
        :param kwargs: forwarded unchanged to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.create_store("structure", initial_data=initial_structure)

    def layout(self):
        """
        Build the TEM layout: the plot container on the left (size 8) and the
        voltage / beam direction inputs on the right (size 4).

        :return: a ``Columns`` layout
        """
        voltage = self.get_numerical_input(
            kwarg_label="voltage",
            default=200,
            label="Voltage / kV",
            help_str="The incident wavelength with which to generate the diffraction pattern, "
            "typically corresponding to a TEM microscope’s voltage.",
        )

        beam_direction = self.get_numerical_input(
            kwarg_label="beam_direction",
            default=[0, 0, 1],
            label="Beam Direction",
            help_str="The direction of the electron beam fired onto the sample.",
            shape=(3,),
            is_int=True,
        )

        # TODO: add additional kwargs for TemCalculator, or switch to an alternative solution

        return Columns(
            [
                Column([Box(Loading(id=self.id("tem-plot")))], size=8),
                Column(
                    [voltage, html.Br(), beam_direction],
                    size=4,
                ),
            ],
        )

    def generate_callbacks(self, app, cache):
        """
        Register the callback that redraws the TEM plot whenever the stored
        structure or any of the kwarg inputs changes.

        :param app: Dash app on which the callback is registered
        :param cache: not used by this component
        """

        @app.callback(
            Output(self.id("tem-plot"), "children"),
            [
                Input(self.id("structure"), "data"),
                Input(self.get_all_kwargs_id(), "value"),
            ],
        )
        def generate_diffraction_pattern(structure, *args):
            # Nothing to draw until a structure has been stored; same guard as
            # XRayDiffractionComponent.pattern_from_struct.
            if not structure:
                raise PreventUpdate

            structure = self.from_data(structure)

            # *args are not read directly; the kwargs are rebuilt by the helper
            # (presumably from the Dash callback context -- confirm).
            kwargs = self.reconstruct_kwargs_from_state()
            calculator = TEMCalculator(**kwargs)

            return dcc.Graph(
                figure=calculator.get_plot_2d(structure),
                responsive=False,
                config={"displayModeBar": False, "displaylogo": False},
            )
class XRayDiffractionComponent(MPComponent):
    """
    Dash component that plots a simulated X-ray diffraction pattern.

    The structure is kept in the "structure" store. One callback converts it
    into a diffraction pattern (stored in the component's default store) and a
    second callback turns that pattern plus the user-selected options
    (radiation source, peak profile, shape factor, crystallite size, x axis)
    into the figure shown in the "xrd-plot" graph.
    """

    # TODO: add pole figures for a given single peak for help quantifying texture

    def __init__(self, *args, initial_structure=None, **kwargs):
        """
        :param args: forwarded unchanged to the parent constructor
        :param initial_structure: initial data for the "structure" store
        :param kwargs: forwarded unchanged to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.create_store("structure", initial_data=initial_structure)

    # Default XRD plot style settings
    default_xrd_plot_style = dict(
        xaxis={
            "title": "2𝜃 / º",
            "anchor": "y",
            "nticks": 8,
            "showgrid": True,
            "showline": True,
            "side": "bottom",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"size": 16.0},
            "type": "linear",
            "zeroline": False,
        },
        yaxis={
            "title": "Intensity / arb. units",
            "anchor": "x",
            "nticks": 7,
            "showgrid": True,
            "showline": True,
            "side": "left",
            "tickfont": {"size": 16.0},
            "ticks": "inside",
            "titlefont": {"size": 16.0},
            "type": "linear",
            "zeroline": False,
        },
        autosize=True,
        hovermode="x",
        height=225,
        showlegend=False,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=60, b=50, t=50, pad=0, r=30),
        title="X-ray Diffraction Pattern",
        template="simple_white",
    )

    # Layout used for the placeholder figure shown before any pattern exists:
    # axes hidden, transparent background.
    empty_plot_style = {
        "xaxis": {"visible": False},
        "yaxis": {"visible": False},
        "paper_bgcolor": "rgba(0,0,0,0)",
        "plot_bgcolor": "rgba(0,0,0,0)",
    }

    @staticmethod
    def G(x, c, alpha):
        """
        Return c-centered Gaussian line shape at x with HWHM alpha.

        :param x: position(s) at which to evaluate the profile
        :param c: peak center
        :param alpha: half-width at half-maximum
        :return: profile value(s) at x
        """
        return (
            np.sqrt(np.log(2) / np.pi)
            / alpha
            * np.exp(-(((x - c) / alpha) ** 2) * np.log(2))
        )

    @staticmethod
    def L(x, c, gamma):
        """
        Return c-centered Lorentzian line shape at x with HWHM gamma.

        :param x: position(s) at which to evaluate the profile
        :param c: peak center
        :param gamma: half-width at half-maximum
        :return: profile value(s) at x
        """
        return gamma / (np.pi * ((x - c) ** 2 + gamma**2))

    @staticmethod
    def V(x, c, alphagamma):
        """
        Return the c-centered Voigt line shape at x, scaled to match HWHM of
        Gaussian and Lorentzian profiles.

        :param x: position(s) at which to evaluate the profile
        :param c: peak center
        :param alphagamma: target half-width; the Gaussian and Lorentzian
            components both use 0.61065 times this value
        :return: profile value(s) at x
        """
        # Equal Gaussian / Lorentzian contributions
        alpha = 0.61065 * alphagamma
        gamma = 0.61065 * alphagamma
        # Convert the Gaussian HWHM to a standard deviation
        sigma = alpha / np.sqrt(2 * np.log(2))
        return np.real(wofz(((x - c) + 1j * gamma) / (sigma * np.sqrt(2)))) / (
            sigma * np.sqrt(2 * np.pi)
        )

    @staticmethod
    def twotheta_to_q(twotheta, xray_wavelength):
        """
        Convert twotheta to Q.

        :param twotheta: in degrees
        :param xray_wavelength: in Ångstroms
        :return: Q = (4 pi / wavelength) * sin(twotheta / 2), i.e. in inverse
            Ångstroms for a wavelength given in Ångstroms
        """
        # thanks @rwoodsrobinson
        return (4 * np.pi / xray_wavelength) * np.sin(np.deg2rad(twotheta) / 2)

    @staticmethod
    def grain_to_hwhm(tau, two_theta, K=0.9, wavelength="CuKa"):
        """
        Scherrer broadening for a given crystallite size.

        :param tau: grain size in nm
        :param two_theta: 2-theta angle in radians (it is halved and passed
            directly to ``np.cos``)
        :param K: shape factor (default 0.9)
        :param wavelength: either a key of ``WAVELENGTHS`` (e.g. "CuKa") or a
            numeric wavelength in Ångstroms; it is converted to nm internally
        :return: half-width half-max (alpha or gamma), for line profile
        """
        if isinstance(wavelength, str):
            wavelength = WAVELENGTHS[wavelength]
        # factor of 0.1 to convert wavelength to nm
        return (
            0.5 * K * 0.1 * wavelength / (tau * abs(np.cos(two_theta / 2)))
        )  # Scherrer equation for half-width half max

    @property
    def _sub_layouts(self):
        """
        Build the individual pieces of the XRD layout.

        :return: dict with keys "x_axis", "graph", "rad_source",
            "peak_profile", "shape_factor", "crystallite_size" and
            "static_image"
        """
        # Initial values for the inputs created below, keyed by kwarg_label
        state = {
            "peak_profile": "G",
            "shape_factor": 0.94,
            "rad_source": "CuKa",
            "x_axis": "twotheta",
            "crystallite_size": 0.1,
        }

        # Main plot
        graph = Loading(
            [
                dcc.Graph(
                    figure=go.Figure(layout=XRayDiffractionComponent.empty_plot_style),
                    id=self.id("xrd-plot"),
                    config={
                        "displayModeBar": False,  # or "hover",
                        "plotGlPixelRatio": 2,
                        "displaylogo": False,
                        # "modeBarButtons": [["toImage"]],  # to only add an image download button
                        "toImageButtonOptions": {
                            "format": "png",
                            "filename": "xrd",
                            "scale": 4,
                            "width": 600,
                            "height": 400,
                        },
                        "editable": True,
                    },
                    responsive=True,
                    animate=False,
                )
            ]
        )

        # Radiation source selector
        rad_source = self.get_choice_input(
            kwarg_label="rad_source",
            state=state,
            label="Radiation source",
            help_str="This defines the wavelength of the incident X-ray radiation.",
            options=[
                {
                    "label": f'{name.replace("a", "α").replace("b", "β")} ({wavelength:.3f} Å)',
                    "value": name,
                }
                for name, wavelength in WAVELENGTHS.items()
            ],
            style={"width": "10rem"},
        )

        # Shape factor input
        # (help text describes the Scherrer constant K; it was previously
        # attached to the peak profile selector by mistake)
        shape_factor = self.get_numerical_input(
            kwarg_label="shape_factor",
            state=state,
            label="Shape Factor",
            help_str="""The shape factor K, also known as the “Scherrer constant” is a
dimensionless quantity to obtain an actual particle size from an apparent particle size
determined from XRD. The discrepancy is because the shape of an individual crystallite
will change the resulting diffraction broadening. Commonly, a value of 0.94 for isotropic
crystals in a spherical shape is used. However, in practice K can vary from 0.62 to 2.08.""",
        )

        # Peak profile selector (Gaussian, Lorentzian, Voigt)
        # (help text describes the available line shapes; it was previously
        # attached to the shape factor input by mistake)
        peak_profile = self.get_choice_input(
            kwarg_label="peak_profile",
            state=state,
            label="Peak Profile",
            help_str="""The peak profile determines what distribute characterizes the broadening of
an XRD pattern. Two extremes are Gaussian distributions, which are useful for peaks with more rounded tops
(typically due to strain broadening) and Lorentzian distributions, which are useful for peaks with
sharper top (typically due to size distributions and dislocations). In reality, peak shapes usually
follow a Voigt distribution, which is a convolution of Gaussian and Lorentzian peak shapes, with the
contribution to both Gaussian and Lorentzian components sample and instrument dependent. Here, both
contributions are equally weighted if Voigt is chosen.""",
            options=[
                {"label": "Gaussian", "value": "G"},
                {"label": "Lorentzian", "value": "L"},
                {"label": "Voigt", "value": "V"},
            ],
            style={"width": "10rem"},
        )

        # 2Theta or Q for x-axis
        x_axis_choice = html.Div(
            [
                self.get_choice_input(
                    kwarg_label="x_axis",
                    state=state,
                    label="Choice of 𝑥 axis",
                    help_str="Can choose between 2𝜃 or Q, where Q is the magnitude of the reciprocal lattice and "
                    "independent of radiation source.",  # TODO: improve
                    options=[
                        {"label": "2𝜃", "value": "twotheta"},
                        {"label": "Q", "value": "Q"},
                    ],
                )
            ],
            style={
                "display": "none"
            },  # TODO: this is buggy! let's fix it before we share
        )

        # Crystallite size selector (via Scherrer Equation)
        # The slider value is the base-10 logarithm of the size in nm: the
        # graph callback computes 10**value before using it.
        crystallite_size = self.get_slider_input(
            kwarg_label="crystallite_size",
            label="Scherrer crystallite size / nm",
            state=state,
            help_str="Simulate a real diffraction pattern by applying Scherrer broadening, which estimates the "
            "full width at half maximum (FWHM) resulting from a finite, rather than infinite, crystallite "
            "size.",
            domain=[-1, 2],
            step=0.01,
            isLogScale=True,
        )

        static_image = self.get_figure_placeholder("xrd-plot")

        return {
            "x_axis": x_axis_choice,
            "graph": graph,
            "rad_source": rad_source,
            "peak_profile": peak_profile,
            "shape_factor": shape_factor,
            "crystallite_size": crystallite_size,
            "static_image": static_image,
        }

    def layout(self, static_image=False):
        """
        Get the standard XRD diffraction pattern layout.

        :param static_image: If True, will show a static image instead of an interactive graph.
        :return: a ``Columns`` layout with the plot on the left (size 8) and
            the option inputs on the right (size 4)
        """
        sub_layouts = self._sub_layouts

        if static_image:
            inner = sub_layouts["static_image"]
        else:
            inner = sub_layouts["graph"]

        return Columns(
            [
                Column(
                    [Box([inner], style={"height": "480px"})],
                    size=8,
                    style={"height": "600px"},
                ),
                Column(
                    [
                        sub_layouts["x_axis"],
                        sub_layouts["rad_source"],
                        sub_layouts["shape_factor"],
                        sub_layouts["peak_profile"],
                        sub_layouts["crystallite_size"],
                    ],
                    size=4,
                ),
            ]
        )

    def generate_callbacks(self, app, cache):
        """
        Register the two XRD callbacks: one that computes the diffraction
        pattern from the stored structure, and one that renders the figure
        from that pattern and the current option values.

        :param app: Dash app on which the callbacks are registered
        :param cache: not used by this component
        """

        @app.callback(
            Output(self.id("xrd-plot"), "figure"),
            [
                Input(self.id(), "data"),
                Input(self.get_kwarg_id("crystallite_size"), "value"),
                Input(self.get_kwarg_id("rad_source"), "value"),
                Input(self.get_kwarg_id("peak_profile"), "value"),
                Input(self.get_kwarg_id("shape_factor"), "value"),
                Input(self.get_kwarg_id("x_axis"), "value"),
            ],
        )
        def update_graph(data, logsize, rad_source, peak_profile, K, x_axis):
            # No pattern stored yet
            if not data:
                raise PreventUpdate

            kwargs = self.reconstruct_kwargs_from_state(callback_context.inputs)

            if not kwargs:
                raise PreventUpdate

            # The raw callback arguments are replaced by the values rebuilt
            # from the callback inputs.
            peak_profile = kwargs["peak_profile"]
            K = kwargs["shape_factor"]
            rad_source = kwargs["rad_source"]
            logsize = float(kwargs["crystallite_size"])
            x_axis = kwargs["x_axis"]

            # Slider is logarithmic: convert back to a size in nm
            grain_size = 10**logsize

            x_peak = data["x"]
            y_peak = data["y"]
            d_hkls = data["d_hkls"]
            hkls = data["hkls"]

            # NOTE(review): get_figure is not defined in the visible part of
            # this class -- confirm that it is provided elsewhere.
            plot = self.get_figure(
                peak_profile,
                K,
                rad_source,
                grain_size,
                x_peak,
                y_peak,
                d_hkls,
                hkls,
                x_axis,
            )

            return plot

        @app.callback(
            Output(self.id(), "data"),
            [
                Input(self.id("structure"), "data"),
                Input(self.get_kwarg_id("rad_source"), "value"),
            ],
        )
        def pattern_from_struct(struct, rad_source):
            if struct is None or not rad_source:
                raise PreventUpdate

            struct = self.from_data(struct)
            rad_source = self.reconstruct_kwarg_from_state(
                callback_context.inputs, "rad_source"
            )

            sga = SpacegroupAnalyzer(struct)
            struct = (
                sga.get_conventional_standard_structure()
            )  # always get conventional structure

            xrdc = XRDCalculator(
                wavelength=WAVELENGTHS[rad_source], symprec=0, debye_waller_factors=None
            )
            data = xrdc.get_pattern(struct, two_theta_range=None)

            return data.as_dict()
        # @app.callback(
        #     Output(self.id("static-image"), "src"),
        #     [Input(self.id("xrd-plot"), "figure")]
        # )
        # def update_static_image(data):
        #
        #     scope = PlotlyScope()
        #     output = scope.transform(data, format="png", width=600, height=400, scale=4)
        #     image = b64encode(output).decode('ascii')
        #
        #     return f"data:image/png;base64,{image}"