"""Module containing plotting utility functionality."""
from __future__ import annotations
from collections.abc import Iterable
from math import ceil
from math import log
from types import MappingProxyType
from typing import TYPE_CHECKING
from warnings import warn
import numpy as np
import xarray as xr
from matplotlib.ticker import Locator
from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable
from pyglotaran_extras.io.utils import result_dataset_mapping
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Hashable
from collections.abc import Mapping
from typing import Literal
from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
from pyglotaran_extras.types import BuiltinSubPlotLabelFormatFunctionKey
from pyglotaran_extras.types import ResultLike
from pyglotaran_extras.types import SubPlotLabelCoord
[docs]
class PlotDuplicationWarning(UserWarning):
    """Warning issued when the plot grid has more subplots than there are datapoints."""
[docs]
def select_irf_dispersion_center_by_index(
    irf_dispersion: xr.DataArray, main_irf_nr: int = 0
) -> xr.DataArray | float:
    """Select a subset of the IRF dispersion data where ``irf_nr==main_irf_nr``.

    Parameters
    ----------
    irf_dispersion : xr.DataArray
        Data Variable from a result dataset which contains the IRF dispersion data.
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.

    Returns
    -------
    xr.DataArray | float
        DataArray only containing the IRF dispersion data for the main IRF,
        collapsed to a plain float if only a single value is left.

    Raises
    ------
    ValueError
        If ``irf_nr`` is not in the coordinates
    """
    if "irf_nr" in irf_dispersion.sizes:
        nr_of_irfs = irf_dispersion.sizes["irf_nr"]
        if main_irf_nr >= nr_of_irfs:
            raise ValueError(
                f"The value main_irf_nr={main_irf_nr!r} is not a valid value for irf_nr, needs "
                f"to be smaller than {nr_of_irfs}."
            )
        irf_dispersion = irf_dispersion.sel(irf_nr=main_irf_nr)
    # Single value left -> unwrap to a plain python float for convenience.
    return irf_dispersion.item() if irf_dispersion.size == 1 else irf_dispersion
[docs]
def maximum_coordinate_range(
    result: ResultLike, coord_name: str = "spectral"
) -> tuple[float, float]:
    """Calculate the minimal and maximal values of a coordinate across datasets.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    coord_name : str
        Name of the coordinate to calculate the maximal range for.

    Returns
    -------
    tuple[float, float]
        Minimal and maximal values across datasets

    See Also
    --------
    plot_fit_overview
    """
    coord_arrays = [
        dataset.coords[coord_name].to_numpy()
        for dataset in result_dataset_mapping(result).values()
    ]
    return (
        min(coord_array.min() for coord_array in coord_arrays),
        max(coord_array.max() for coord_array in coord_arrays),
    )
[docs]
def select_plot_wavelengths(
    result: ResultLike,
    axes_shape: tuple[int, int] = (4, 4),
    wavelength_range: tuple[float, float] | None = None,
    equidistant_wavelengths: bool = True,
) -> Iterable[float]:
    """Select wavelengths to be used in ``plot_fit_overview`` from a result.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping of datasets.
    axes_shape : tuple[int, int]
        Shape of the plot grid (N, M). Defaults to (4, 4).
    wavelength_range : tuple[float, float] | None
        Tuple of minimum and maximum values to calculate the wavelengths
        used for plotting. If not provided the values will be determined over all datasets.
        Defaults to None.
    equidistant_wavelengths : bool
        Whether wavelengths should be selected based on equidistant values
        or equidistant indices (equidistant indices are only supported for a single dataset).
        Defaults to True.

    Returns
    -------
    Iterable[float]
        Wavelengths which should be used for each subplot by ``plot_fit_overview``.

    See Also
    --------
    maximum_coordinate_range
    """
    result_map = result_dataset_mapping(result)
    # One wavelength per subplot in the grid.
    nr_of_plots = int(np.prod(axes_shape))
    if wavelength_range is None:
        wavelength_range = maximum_coordinate_range(result_map)
    if equidistant_wavelengths:
        return np.linspace(*wavelength_range, num=nr_of_plots)
    first_dataset = next(iter(result_map.keys()))
    if len(result_map) > 1:
        # Index based selection is ambiguous across datasets with different spectral axes,
        # so only the first dataset is used and the user is informed.
        warn(
            UserWarning(
                "Calculating plot wavelengths is only supported for a single dataset. "
                f"The dataset {first_dataset!r} will be used to calculate the selected "
                "wavelengths. To mute this warning call "
                f"'{select_plot_wavelengths.__name__}' with only one dataset."
            ),
            stacklevel=2,
        )
    spectral_coords = result_map[first_dataset].coords["spectral"]
    # Restrict to the requested wavelength window before picking equidistant indices.
    spectral_coords = spectral_coords[
        (wavelength_range[0] <= spectral_coords) & (spectral_coords <= wavelength_range[1])
    ]
    spectral_indices = np.linspace(0, len(spectral_coords) - 1, num=nr_of_plots, dtype=np.int64)
    return spectral_coords[spectral_indices].to_numpy()
[docs]
def shift_time_axis_by_irf_location(
    plot_data: xr.DataArray, irf_location: float | None, *, _internal_call: bool = False
) -> xr.DataArray:
    """Shift ``plot_data`` 'time' axis by the position of the main ``irf``.

    Parameters
    ----------
    plot_data : xr.DataArray
        Data to plot.
    irf_location : float | None
        Location of the ``irf``, if the value is None the original ``plot_data`` will be returned.
    _internal_call : bool
        This indicates internal use stripping away user help and silently skipping execution.
        Defaults to False.

    Returns
    -------
    xr.DataArray
        ``plot_data`` with the time axis shifted by the position of the main ``irf``.

    Raises
    ------
    ValueError
        If ``plot_data`` does not have a time axis.

    See Also
    --------
    extract_irf_location
    """
    if irf_location is None:
        return plot_data
    if "time" not in plot_data.coords:
        # Internal callers silently skip shifting; user facing calls get a helpful error.
        if _internal_call is True:
            return plot_data
        msg = "plot_data need to have a 'time' axis."
        raise ValueError(msg)
    shifted_times = plot_data.coords["time"] - irf_location
    return plot_data.assign_coords(time=shifted_times)
[docs]
def get_shifted_traces(
    res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int = 0
) -> xr.DataArray:
    """Shift traces by the position of the main ``irf``.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset from a pyglotaran optimization.
    center_λ : float | None
        Center wavelength (λ in nm). Defaults to None.
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.

    Returns
    -------
    xr.DataArray
        Traces shifted by the ``irf``s location, to align the at 0.

    Raises
    ------
    ValueError
        If no known concentration was found in the result.
    """
    # Results may store concentrations under either of these variable names.
    for concentration_key in ("species_concentration", "species_associated_concentrations"):
        if concentration_key in res:
            traces = res[concentration_key]
            break
    else:
        msg = f"No concentrations in result:\n{res}"
        raise ValueError(msg)
    irf_location = extract_irf_location(res, center_λ, main_irf_nr)
    return shift_time_axis_by_irf_location(traces, irf_location)
[docs]
def ensure_axes_array(axes: Axis | Axes) -> Axes:
    """Ensure that axes have flatten method even if it is a single axis.

    Parameters
    ----------
    axes : Axis | Axes
        Axis or Axes to convert for API consistency.

    Returns
    -------
    Axes
        Numpy ndarray of axes.
    """
    # `Axis` can't be used with `isinstance`, so a bare axis is detected by the
    # absence of the `flatten` method that the np.ndarray based `Axes` provide.
    return axes if hasattr(axes, "flatten") else np.array([axes])
[docs]
def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None:
    """Add cycler to and axis if it is not None.

    This is a convenience function that allow to opt out of using
    a cycler, which is needed to run a plotting function in a loop
    where the cycler is controlled from the outside.

    Parameters
    ----------
    axis : Axis | Axes
        Axis to plot on.
    cycler : Cycler | None
        Plot style cycler to use.
    """
    if cycler is None:
        return
    for single_axis in ensure_axes_array(axis).flatten():
        single_axis.set_prop_cycle(cycler)
[docs]
def abs_max(
    data: xr.DataArray, *, result_dims: Hashable | Iterable[Hashable] = ()
) -> xr.DataArray:
    """Calculate the absolute maximum values of ``data`` along all dims except ``result_dims``.

    Parameters
    ----------
    data : xr.DataArray
        Data for which the absolute maximum should be calculated.
    result_dims : Hashable | Iterable[Hashable]
        Dimensions of ``data`` which should be preserved and part of the resulting DataArray.
        Defaults to () which results in using the absolute maximum of all values.

    Returns
    -------
    xr.DataArray
        Absolute maximum values of ``data`` with dimensions ``result_dims``.
    """
    # A string names a single dimension but is itself Iterable, so it needs wrapping
    # like any other non-iterable Hashable. Otherwise ``dim not in result_dims`` would
    # do a substring check (e.g. dim "t" would wrongly be preserved for result_dims="time").
    if isinstance(result_dims, str) or not isinstance(result_dims, Iterable):
        result_dims = (result_dims,)
    reduce_dims = [dim for dim in data.dims if dim not in result_dims]
    return np.abs(data).max(dim=reduce_dims)
[docs]
def calculate_ticks_in_units_of_pi(
    values: np.ndarray | xr.DataArray, *, step_size: float = 0.5
) -> tuple[Iterable[float], Iterable[str]]:
    r"""Calculate tick values and labels in units of Pi.

    Parameters
    ----------
    values : np.ndarray | xr.DataArray
        Values which the ticks should be calculated for.
    step_size : float
        Step size of the ticks in units of pi. Defaults to 0.5

    Returns
    -------
    tuple[Iterable[float], Iterable[str]]
        Tick values and tick labels

    See Also
    --------
    pyglotaran_extras.plotting.plot_doas.plot_doas

    Examples
    --------
    If you have a case study that uses a ``damped-oscillation`` megacomplex you can plot the
    ``damped_oscillation_phase`` with y-tick in units of Pi by the following code given that the
    dataset is saved under ``dataset.nc``.

    .. code-block:: python

        import matplotlib.pyplot as plt
        from glotaran.io import load_dataset
        from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi

        dataset = load_dataset("dataset.nc")
        fig, ax = plt.subplots(1, 1)
        damped_oscillation_phase = dataset["damped_oscillation_phase"].sel(
            damped_oscillation=["osc1"]
        )
        damped_oscillation_phase.plot.line(x="spectral", ax=ax)
        ax.set_yticks(
            *calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal"
        )
    """
    values = np.array(values)
    # Express the value range as integer multiples of (pi * step_size).
    steps_over_pi = np.round(values / np.pi / step_size)
    tick_labels = np.arange(steps_over_pi.min(), steps_over_pi.max() + 1) * step_size
    tick_values = tick_labels * np.pi
    formatted_labels = (
        str(label) for label in pretty_format_numerical_iterable(tick_labels, decimal_places=1)
    )
    return tick_values, formatted_labels
[docs]
def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]:
    """Names of dimensions in ``data`` which don't have a size equal to one.

    This helper function is for example used to determine if a data only have a single trace,
    since this requires different plotting code (e.g. ``data_array.plot.line(x="time")``).

    Parameters
    ----------
    data_array : xr.DataArray
        DataArray to check if it has only a single dimension.

    Returns
    -------
    list[Hashable]
        Names of dimensions in ``data`` which don't have a size equal to one.
    """
    multi_value_dims: list[Hashable] = []
    for dim_name, coord_values in data_array.coords.items():
        if coord_values.size != 1:
            multi_value_dims.append(dim_name)
    return multi_value_dims
[docs]
class MinorSymLogLocator(Locator):
    """Dynamically find minor tick positions based on major ticks for a symlog scaling.
    Ref.: https://stackoverflow.com/a/45696768
    """
    def __init__(self, linthresh: float, nints: int = 10) -> None:
        """Ticks will be placed between the major ticks.
        The placement is linear for x between -linthresh and linthresh,
        otherwise its logarithmically. nints gives the number of
        intervals that will be bounded by the minor ticks.
        Parameters
        ----------
        linthresh : float
            A single float which defines the range (-x, x), within which the plot is linear.
        nints : int
            Number of minor tick between major ticks. Defaults to 10
        """
        # Half-width of the linear region of the symlog scale.
        self.linthresh = linthresh
        # Number of intervals between consecutive major ticks (used in the linear region).
        self.nintervals = nints
    def __call__(self) -> list[float]:
        """Return the locations of the ticks.
        Returns
        -------
        list[float]
            Minor ticks position.
        """
        # Return the locations of the ticks
        majorlocs = self.axis.get_majorticklocs()
        if len(majorlocs) == 1:
            # A single major tick leaves no interval to subdivide.
            return self.raise_if_exceeds(np.array([]))
        # add temporary major tick locs at either end of the current range
        # to fill in minor tick gaps
        dmlower = majorlocs[1] - majorlocs[0] # major tick difference at lower end
        dmupper = majorlocs[-1] - majorlocs[-2] # major tick difference at upper end
        # add temporary major tick location at the lower end
        if majorlocs[0] != 0.0 and (
            (majorlocs[0] != self.linthresh and dmlower > self.linthresh)
            or (dmlower == self.linthresh and majorlocs[0] < 0)
        ):
            # Lowest tick looks logarithmically spaced: extend by a factor of 10.
            majorlocs = np.insert(majorlocs, 0, majorlocs[0] * 10.0)
        else:
            # Lowest tick looks linearly spaced: extend by one linthresh step.
            majorlocs = np.insert(majorlocs, 0, majorlocs[0] - self.linthresh)
        # add temporary major tick location at the upper end
        if majorlocs[-1] != 0.0 and (
            (np.abs(majorlocs[-1]) != self.linthresh and dmupper > self.linthresh)
            or (dmupper == self.linthresh and majorlocs[-1] > 0)
        ):
            # Highest tick looks logarithmically spaced: extend by a factor of 10.
            majorlocs = np.append(majorlocs, majorlocs[-1] * 10.0)
        else:
            # Highest tick looks linearly spaced: extend by one linthresh step.
            majorlocs = np.append(majorlocs, majorlocs[-1] + self.linthresh)
        # iterate through minor locs
        minorlocs: list[float] = []
        # handle the lowest part
        for i in range(1, len(majorlocs)):
            majorstep = majorlocs[i] - majorlocs[i - 1]
            if abs(majorlocs[i - 1] + majorstep / 2) < self.linthresh:
                # Interval midpoint lies in the linear region: use all nintervals divisions.
                ndivs = self.nintervals
            else:
                # Logarithmic region: one fewer division — presumably so minor ticks
                # align with log-spaced values (see Ref. in class docstring).
                ndivs = self.nintervals - 1
            minorstep = majorstep / ndivs
            # Drop the first location since it coincides with the major tick itself.
            locs = np.arange(majorlocs[i - 1], majorlocs[i], minorstep)[1:]
            minorlocs.extend(locs)
        return self.raise_if_exceeds(np.array(minorlocs))
[docs]
    def tick_values(self, _vmin: float, _vmax: float) -> None:
        """Return the values of the located ticks given **_vmin** and **_vmax** (not implemented).
        Parameters
        ----------
        _vmin : float
            Minimum value.
        _vmax : float
            Maximum value.
        Raises
        ------
        NotImplementedError
            Not used
        """
        msg = f"Cannot get tick locations for a {type(self)} type."
        raise NotImplementedError(msg)
# Read-only mapping of builtin subplot label format styles to formatting functions.
# Each function takes the subplot index (1-based, see ``add_subplot_labels``) and the
# total number of subplots (or None) and returns the label string.
# NOTE(review): ``format_sub_plot_number_upper_case_letter`` is defined elsewhere in
# this module — its exact letter scheme is assumed from its name; verify there.
BuiltinSubPlotLabelFormatFunctions: Mapping[str, Callable[[int, int | None], str]] = (
    MappingProxyType(
        {
            "number": lambda x, _: f"{x}",
            "upper_case_letter": format_sub_plot_number_upper_case_letter,
            "lower_case_letter": lambda x, y: format_sub_plot_number_upper_case_letter(
                x, y
            ).lower(),
        }
    )
)
[docs]
def add_subplot_labels(
    axes: Axis | Axes,
    *,
    label_position: tuple[float, float] = (-0.05, 1.05),
    label_coords: SubPlotLabelCoord = "axes fraction",
    direction: Literal["row", "column"] = "row",
    label_format_template: str = "{}",
    label_format_function: BuiltinSubPlotLabelFormatFunctionKey
    | Callable[[int, int | None], str] = "number",
    fontsize: int = 16,
) -> None:
    """Add labels to all subplots in ``axes`` in a consistent manner.

    Parameters
    ----------
    axes : Axis | Axes
        Axes (subplots) on which the labels should be added.
    label_position : tuple[float, float]
        Position of the label in ``label_coords`` coordinates.
    label_coords : SubPlotLabelCoord
        Coordinate system used for ``label_position``. Defaults to "axes fraction"
    direction : Literal["row", "column"]
        Direction in which the axes should be iterated in. Defaults to "row"
    label_format_template : str
        Template string to inject the return value of ``label_format_function`` into.
        Defaults to "{}"
    label_format_function : BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str]
        Function to calculate the label for the axis index and ``axes`` size. Defaults to "number"
    fontsize : int
        Font size used for the label. Defaults to 16
    """
    axes = ensure_axes_array(axes)
    format_function = get_subplot_label_format_function(label_format_function)
    # Transposing flips the flattening order from row-major to column-major.
    if direction == "column":
        axes = axes.T
    for axis_index, axis in enumerate(axes.flatten(), start=1):
        label = label_format_template.format(format_function(axis_index, axes.size))
        axis.annotate(
            label,
            xy=label_position,
            xycoords=label_coords,
            fontsize=fontsize,
        )