# Source code for crystal_toolkit.core.legend

from __future__ import annotations

import os
import warnings
from collections import defaultdict
from itertools import chain
from typing import Any

import numpy as np
from matplotlib.cm import get_cmap
from monty.json import MSONable
from monty.serialization import loadfn
from palettable.colorbrewer.qualitative import Set1_9
from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
from pymatgen.core.periodic_table import Element, Specie
from pymatgen.core.structure import Molecule, Site, SiteCollection
from pymatgen.util.string import unicodeify_species
from sklearn.preprocessing import LabelEncoder
from webcolors import html5_parse_legacy_color, html5_serialize_simple_color

# Element color tables (forked from pymatgen), loaded once at import time from
# the YAML file that sits next to this module. The resulting mapping is looked
# up by color scheme name first and element symbol second.
module_dir = os.path.dirname(os.path.abspath(__file__))
_color_scheme_path = os.path.join(module_dir, "ElementColorSchemes.yaml")
EL_COLORS = loadfn(_color_scheme_path)

class Legend(MSONable):
    """
    Help generate a legend (colors and radii) for a Structure or Molecule such
    that colors and radii can be displayed for the appropriate species.

    Note that species themselves have a color (for example, Oxygen is typically
    red), but that we might also want to color-code by site properties (for
    example, magnetic moment), thus this class has to take into account both the
    species present and its context (the specific site the species is at) to
    correctly generate the legend.
    """

    # color scheme used when the requested one is not available
    default_color_scheme = "Jmol"
    # RGB color used when no color can be determined (black)
    default_color = [0, 0, 0]
    default_radius = 1.0
    # radius used (with a warning) when the chosen scheme has no value
    fallback_radius = 0.5
    # radius used for every species with the "uniform" radius scheme
    uniform_radius = 0.5

    def __init__(
        self,
        site_collection: SiteCollection | Site,
        color_scheme: str = "Jmol",
        radius_scheme: str = "uniform",
        cmap: str = "coolwarm",
        cmap_range: tuple[float, float] | None = None,
    ):
        """
        Create a legend for a given SiteCollection to choose how to display
        colors and radii for the given sites and the species on those sites.

        If a site has a "display_color" or "display_radius" site property
        defined, this can be used to manually override the displayed colors
        and radii respectively.

        Args:
            site_collection: SiteCollection or, for convenience, a single site
                can be provided and this will be converted into a
                SiteCollection
            color_scheme: choose how to color-code species, one of "Jmol",
                "VESTA", "accessible" or a scalar site property (e.g. magnetic
                moment) or a categorical/string site property (e.g. Wyckoff
                label)
            radius_scheme: choose the radius for a species, one of "atomic",
                "specified_or_average_ionic", "covalent", "van_der_waals",
                "atomic_calculated", "uniform"
            cmap: only used if color_scheme is set to a scalar site property,
                defines the matplotlib color map to use, by default is
                blue-white-red for negative to positive values
            cmap_range: only used if color_scheme is set to a scalar site
                property, defines the minimum and maximum values of the color
                scale
        """
        if isinstance(site_collection, Site):
            site_collection = Molecule.from_sites([site_collection])

        site_prop_types = self.analyze_site_props(site_collection)

        self.allowed_color_schemes = (
            ["VESTA", "Jmol", "accessible"]
            + site_prop_types.get("scalar", [])
            + site_prop_types.get("categorical", [])
        )

        self.allowed_radius_schemes = (
            "atomic",
            "specified_or_average_ionic",
            "covalent",
            "van_der_waals",
            "atomic_calculated",
            "uniform",
        )

        if color_scheme not in self.allowed_color_schemes:
            warnings.warn(
                f"Color scheme {color_scheme} not available, "
                f"falling back to {self.default_color_scheme}."
            )
            color_scheme = self.default_color_scheme

        # if color-coding by a scalar site property, determine minimum and
        # maximum values for color scheme, will default to be symmetric
        # about zero
        if color_scheme in site_prop_types.get("scalar", []) and not cmap_range:
            props = np.array(
                [
                    p
                    for p in site_collection.site_properties[color_scheme]
                    if p is not None
                ]
            )
            # largest magnitude present; plain float keeps the range
            # serializable
            prop_max = float(np.max(np.abs(props)))
            cmap_range = (-prop_max, prop_max)

        el_colors = EL_COLORS.copy()
        el_colors.update(
            self.generate_accessible_color_scheme_on_the_fly(site_collection)
        )

        self.categorical_colors = self.generate_categorical_color_scheme_on_the_fly(
            site_collection, site_prop_types
        )

        self.el_colors = el_colors
        self.site_prop_types = site_prop_types
        self.site_collection = site_collection
        self.color_scheme = color_scheme
        self.radius_scheme = radius_scheme
        self.cmap = cmap
        self.cmap_range = cmap_range

    @staticmethod
    def generate_accessible_color_scheme_on_the_fly(
        site_collection: SiteCollection,
    ) -> dict[str, dict[str, tuple[int, int, int]]]:
        """
        Generate a color scheme for the elements present using a small
        palette, e.g. for a color scheme more appropriate for people with
        color blindness.

        Args:
            site_collection: SiteCollection

        Returns:
            A dictionary in similar format to EL_COLORS, with the single
            top-level key "accessible"
        """
        color_scheme = {}

        all_species = set(chain.from_iterable(site_collection.species_and_occu))
        # de-duplicate element symbols: several species (e.g. different
        # oxidation states) can share one element and must share one color
        all_elements = sorted({sp.as_dict()["element"] for sp in all_species})

        # Okabe-Ito colorblind-safe palette, plus white
        palette = [
            (0, 0, 0),  # 0, black
            (230, 159, 0),  # 1, orange
            (86, 180, 233),  # 2, sky blue
            (0, 158, 115),  # 3, bluish green
            (240, 228, 66),  # 4, yellow
            (0, 114, 178),  # 5, blue
            (213, 94, 0),  # 6, vermilion
            (204, 121, 167),  # 7, reddish purple
            (255, 255, 255),  # 8, white
        ]

        # similar to CPK, mapping element to palette index
        preferred_colors = {
            "O": 6,
            "N": 2,
            "C": 0,
            "H": 8,
            "F": 3,
            "Cl": 3,
            "Fe": 1,
            "Br": 7,
            "I": 7,
            "P": 1,
            "S": 4,
        }

        if len(all_elements) > len(palette):
            warnings.warn(
                "Too many distinct types of site to use an accessible color scheme, "
                "some sites will be given the default color."
            )

        # first hand out preferred colors, first come first served when two
        # elements present share the same preferred palette index
        colors_assigned = []
        for el in all_elements:
            idx = preferred_colors.get(el)
            if idx is not None and idx not in colors_assigned:
                color_scheme[el] = palette[idx]
                colors_assigned.append(idx)

        # then give whatever palette entries are left to the other elements,
        # taken from the end of the palette; elements left over once the
        # palette runs out get no entry at all
        remaining_elements = [el for el in all_elements if el not in color_scheme]
        remaining_palette = [
            c for idx, c in enumerate(palette) if idx not in colors_assigned
        ]
        for el in remaining_elements:
            if remaining_palette:
                color_scheme[el] = remaining_palette.pop()

        return {"accessible": color_scheme}

    @staticmethod
    def generate_categorical_color_scheme_on_the_fly(
        site_collection: SiteCollection, site_prop_types
    ) -> dict[str, dict[str, tuple[int, int, int]]]:
        """
        Generate a color for each distinct value of every categorical site
        property, e.g. for Wyckoff labels.

        Args:
            site_collection: SiteCollection
            site_prop_types: dictionary as returned by analyze_site_props

        Returns:
            A dictionary in similar format to EL_COLORS, keyed by site
            property name and then by category
        """
        color_scheme = {}
        palette = Set1_9.colors

        for site_prop_name in site_prop_types.get("categorical", []):
            # missing values become their own "None" category so that every
            # label is a string and can be encoded
            props = [
                "None" if p is None else p
                for p in site_collection.site_properties[site_prop_name]
            ]

            # the encoder has to be fitted before it can transform
            transformed_props = LabelEncoder().fit_transform(props)

            # if we have more categories than available colors,
            # arbitrarily group some categories together
            if len(set(props)) > len(palette):
                warnings.warn(
                    "Too many categories for a complete categorical color scheme."
                )
                transformed_props = [
                    p if p < len(palette) else -1 for p in transformed_props
                ]

            color_scheme[site_prop_name] = {
                name: palette[p] for name, p in zip(props, transformed_props)
            }

        return color_scheme

    def get_color(self, sp: Specie | Element, site: Site | None = None) -> str:
        """
        Get a color to render a specific species. Optionally, you can provide
        a site for context, since the color may depend on a site property
        rather than on the species alone.

        Args:
            sp: Specie or Element
            site: Site

        Returns:
            Color, serialized by html5_serialize_simple_color

        Raises:
            ValueError: if the color scheme is based on a site property and no
                site is given, or if the color scheme is not recognized
        """
        # allow manual override by user
        if site is not None and "display_color" in site.properties:
            color = site.properties["display_color"]
            # TODO: next two lines due to change in API, will be removed
            if isinstance(color, list) and isinstance(color[0], str):
                color = color[0]
            if isinstance(color, (list, tuple)):
                return html5_serialize_simple_color(color)
            return html5_serialize_simple_color(html5_parse_legacy_color(color))

        if self.color_scheme in ("VESTA", "Jmol", "accessible"):
            el = sp.as_dict()["element"]
            color = self.el_colors[self.color_scheme].get(
                el, self.el_colors["Extras"].get(el, self.default_color)
            )

        elif self.color_scheme in self.site_prop_types.get("scalar", []):
            if site is None:
                raise ValueError(
                    "Requires a site for context to get the "
                    "appropriate site property."
                )

            prop = site.properties[self.color_scheme]

            # test against None explicitly: a value of zero is a legitimate
            # property value and must still be mapped through the color map
            if prop is not None:
                cmap = get_cmap(self.cmap)

                # normalize in [0, 1] range, as expected by cmap
                prop_min, prop_max = self.cmap_range
                span = prop_max - prop_min
                # a zero-width range (e.g. all values are zero) maps to the
                # middle of the color map instead of dividing by zero
                prop_normed = float((prop - prop_min) / span) if span else 0.5

                color = [int(c * 255) for c in cmap(prop_normed)[0:3]]
            else:
                # fallback if site prop is None
                color = self.default_color

        elif self.color_scheme in self.site_prop_types.get("categorical", []):
            if site is None:
                raise ValueError(
                    "Requires a site for context to get the "
                    "appropriate site property."
                )

            prop = site.properties[self.color_scheme]

            color = self.categorical_colors[self.color_scheme].get(
                prop, self.default_color
            )

        else:
            raise ValueError(
                f"Unknown color for {sp} and color scheme {self.color_scheme}."
            )

        return html5_serialize_simple_color(color)

    def get_radius(self, sp: Specie | Element, site: Site | None = None) -> float:
        """
        Get a radius to render a specific species, according to the radius
        scheme chosen for this legend.

        Args:
            sp: Specie or Element
            site: Site, optional, used only to look for a "display_radius"
                site property that overrides the radius scheme

        Returns:
            Radius; the fallback radius is returned (with a warning) when the
            chosen scheme has no value for this species

        Raises:
            ValueError: if the radius scheme is not one of the allowed schemes
        """
        # allow manual override by user
        if site is not None and "display_radius" in site.properties:
            return site.properties["display_radius"]

        if self.radius_scheme not in self.allowed_radius_schemes:
            raise ValueError(
                f"Unknown radius scheme {self.radius_scheme}, "
                f"choose from: {self.allowed_radius_schemes}."
            )

        radius = None
        try:
            if self.radius_scheme == "uniform":
                radius = self.uniform_radius
            elif self.radius_scheme == "atomic":
                radius = float(sp.atomic_radius)
            elif self.radius_scheme == "specified_or_average_ionic":
                if isinstance(sp, Specie) and sp.oxi_state:
                    radius = float(sp.ionic_radius)
                else:
                    radius = float(sp.average_ionic_radius)
            elif self.radius_scheme == "covalent":
                el = str(getattr(sp, "element", sp))
                radius = float(CovalentRadius.radius[el])
            elif self.radius_scheme == "van_der_waals":
                radius = float(sp.van_der_waals_radius)
            elif self.radius_scheme == "atomic_calculated":
                radius = float(sp.atomic_radius_calculated)
        except (TypeError, ValueError, KeyError, AttributeError):
            # the radius is not tabulated for this species (e.g. the lookup
            # gave None or the element is missing from the table); use the
            # fallback below instead of failing
            radius = None

        if (not radius) or (not isinstance(radius, float)):
            warnings.warn(
                f"Radius unknown for {sp} and strategy {self.radius_scheme}, "
                f"setting to {self.fallback_radius}."
            )
            radius = self.fallback_radius

        return radius

    @staticmethod
    def analyze_site_props(site_collection: SiteCollection) -> dict[str, list[str]]:
        """
        Sort the site properties of a SiteCollection by the kind of value
        they hold.

        Args:
            site_collection: SiteCollection

        Returns:
            A dictionary with keys "scalar", "matrix", "vector", "categorical"
            and values of a list of site property names corresponding to each
            type; keys with no matching site property are omitted
        """
        # (implicitly assumes all site props for a given key are same type)
        site_prop_names = defaultdict(list)
        for name, props in site_collection.site_properties.items():
            # classify on the first value that is actually set, so that a
            # property missing on the first site is still recognized
            sample = next((p for p in props if p is not None), None)
            if sample is None:
                continue
            if isinstance(sample, (float, int, np.floating, np.integer)):
                site_prop_names["scalar"].append(name)
            elif isinstance(sample, list) and len(sample) == 3:
                if isinstance(sample[0], list) and len(sample[0]) == 3:
                    site_prop_names["matrix"].append(name)
                else:
                    site_prop_names["vector"].append(name)
            elif isinstance(sample, str):
                site_prop_names["categorical"].append(name)

        return dict(site_prop_names)

    @staticmethod
    def get_species_str(sp: Specie | Element) -> str:
        """
        Args:
            sp: Specie or Element

        Returns:
            string representation
        """
        # TODO: add roman numerals for oxidation state for ease of readability
        # and then move this to pymatgen string utils ...
        return unicodeify_species(str(sp))

    def get_legend(self) -> dict[str, Any]:
        """
        Build the legend for the site collection and the current color scheme.

        Returns:
            A dictionary with keys "composition" (the composition of the site
            collection as a dictionary), "colors" (mapping each color to a
            comma-separated string of the labels shown in that color) and
            "available_color_schemes" (the names of the scalar and categorical
            site properties)

        Raises:
            ValueError: if the color scheme is not recognized
        """
        scheme = self.color_scheme

        # decide what we want the labels to be
        if scheme in ("Jmol", "VESTA", "accessible"):

            def label(site, sp):
                return self.get_species_str(sp)

        elif scheme in self.site_prop_types.get("scalar", []):

            def label(site, sp):
                value = site.properties[scheme]
                # a missing value cannot be formatted as a float
                return "None" if value is None else f"{value:.2f}"

        elif scheme in self.site_prop_types.get("categorical", []):

            def label(site, sp):
                return f"{site.properties[scheme]}"

        else:
            raise ValueError(f"Color scheme {self.color_scheme} not known.")

        legend = defaultdict(list)

        # first get all our colors for different species
        for site in self.site_collection:
            for sp in site.species:
                legend[self.get_color(sp, site)].append(label(site, sp))

        legend = {k: ", ".join(sorted(set(v))) for k, v in legend.items()}

        color_options = [
            prop
            for site_prop_type in ("scalar", "categorical")
            for prop in self.site_prop_types.get(site_prop_type, [])
        ]

        return {
            "composition": self.site_collection.composition.as_dict(),
            "colors": legend,
            "available_color_schemes": color_options,
        }