Source code for pyglotaran_extras.plotting.plot_traces

"""Module containing functionality to plot fitted traces."""

from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

import matplotlib.pyplot as plt

from pyglotaran_extras.io.utils import result_dataset_mapping
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import PlotDuplicationWarning
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import add_unique_figure_legend
from pyglotaran_extras.plotting.utils import extract_dataset_scale
from pyglotaran_extras.plotting.utils import extract_irf_location
from pyglotaran_extras.plotting.utils import select_plot_wavelengths

# Names exported by ``from ... import *``. ``select_plot_wavelengths`` is imported above
# from ``pyglotaran_extras.plotting.utils`` and re-exported from this module.
# NOTE(review): ``plot_data_and_fits`` is not listed here -- confirm this is intentional.
__all__ = ["select_plot_wavelengths", "plot_fitted_traces"]

if TYPE_CHECKING:
    from collections.abc import Iterable

    from cycler import Cycler
    from matplotlib.axis import Axis
    from matplotlib.figure import Figure
    from matplotlib.pyplot import Axes

    from pyglotaran_extras.types import ResultLike


def _skip_cycler_entries(axis: Axis, count: int) -> None:
    """Advance the property cycle of ``axis`` by ``count`` entries without plotting.

    Parameters
    ----------
    axis : Axis
        Axis whose line property cycle should be advanced.
    count : int
        Number of cycle entries to consume.
    """
    line_props = axis._get_lines
    if hasattr(line_props, "prop_cycler"):
        # Older matplotlib exposes the cycle as an iterator on this private attribute.
        for _ in range(count):
            next(line_props.prop_cycler)
    else:
        # Newer matplotlib (3.8+) no longer has ``prop_cycler``; accessing it raises
        # ``AttributeError``. ``get_next_color`` is used to consume one entry instead.
        # NOTE(review): it appears to advance only when the cycle defines a "color"
        # key -- confirm that holds for every cycler passed in here.
        for _ in range(count):
            line_props.get_next_color()


def plot_data_and_fits(
    result: ResultLike,
    wavelength: float,
    axis: Axis,
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
    show_zero_line: bool = True,
) -> None:
    """Plot data and fits for a given ``wavelength`` on a given ``axis``.

    Datasets with only a single ``time`` coordinate are ignored entirely.
    If the wavelength lies outside the ``spectral`` range of a dataset, that dataset is
    skipped, but two entries of the property cycle (one for the data trace and one for
    the fit trace) are still consumed, presumably so each dataset keeps the same style
    across subplots.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    wavelength : float
        Wavelength to plot data and fits for.
    axis : Axis
        Axis to plot the data and fits on.
    center_λ : float | None
        Center wavelength (λ in nm)
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend : bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    y_label : str
        Label used for the y-axis of each subplot.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.
    show_zero_line : bool
        Whether or not to add a horizontal line at zero. Defaults to True.

    See Also
    --------
    plot_fit_overview
    """
    result_map = result_dataset_mapping(result)
    add_cycler_if_not_none(axis, cycler)
    for dataset_name in result_map:
        dataset = result_map[dataset_name]
        if dataset.coords["time"].to_numpy().size == 1:
            continue
        spectral_coords = dataset.coords["spectral"].to_numpy()
        if spectral_coords.min() <= wavelength <= spectral_coords.max():
            result_data = dataset.sel(spectral=[wavelength], method="nearest")
            scale = extract_dataset_scale(result_data, divide_by_scale)
            irf_loc = extract_irf_location(result_data, center_λ, main_irf_nr)
            # Shift the time axis by the IRF location returned above.
            result_data = result_data.assign_coords(time=result_data.coords["time"] - irf_loc)
            (result_data.data / scale).plot(x="time", ax=axis, label=f"{dataset_name}_data")
            (result_data.fitted_data / scale).plot(
                x="time", ax=axis, label=f"{dataset_name}_fit"
            )
        else:
            # Nothing is drawn for this dataset; consume the two cycle entries its
            # data and fit traces would have used.
            _skip_cycler_entries(axis, 2)
    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
        axis.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
    if show_zero_line is True:
        axis.axhline(0, color="k", linewidth=1)
    axis.set_ylabel(y_label)
    if per_axis_legend is True:
        axis.legend()
def plot_fitted_traces(
    result: ResultLike,
    wavelengths: Iterable[float],
    axes_shape: tuple[int, int] = (4, 4),
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    figsize: tuple[float, float] = (30, 15),
    title: str = "Fit overview",
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
    show_zero_line: bool = True,
) -> tuple[Figure, Axes]:
    """Plot data and their fit in per wavelength plot grid.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping of datasets.
    wavelengths : Iterable[float]
        Wavelength which should be used for each subplot, needs to be of length N*M
        with ``axes_shape`` being of shape (N, M).
    axes_shape : tuple[int, int]
        Shape of the plot grid (N, M). Defaults to (4, 4).
    center_λ : float | None
        Center wavelength of the IRF (λ in nm).
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend : bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    figsize : tuple[float, float]
        Size of the figure (N, M) in inches. Defaults to (30, 15).
    title : str
        Title to add to the figure. Defaults to "Fit overview".
    y_label : str
        Label used for the y-axis of each subplot.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.
    show_zero_line : bool
        Whether or not to add a horizontal line at zero. Defaults to True.

    Returns
    -------
    tuple[Figure, Axes]
        Figure and axes which can then be refined by the user.

    Raises
    ------
    ValueError
        If the number of ``wavelengths`` differs from the number of subplots (N*M).

    See Also
    --------
    maximum_coordinate_range
    add_unique_figure_legend
    plot_data_and_fits
    calculate_wavelengths
    """
    result_map = result_dataset_mapping(result)

    # Materialize once so the length can be checked (``wavelengths`` may be a generator).
    wavelengths = list(wavelengths)
    expected_nr_of_plots = axes_shape[0] * axes_shape[1]
    if len(wavelengths) != expected_nr_of_plots:
        # Fail before ``plt.subplots`` so no half drawn figure is left registered in
        # pyplot; the strict ``zip`` below would raise the same exception type, but only
        # after the figure was created and partially filled.
        raise ValueError(
            f"Expected {expected_nr_of_plots} wavelengths for axes_shape={axes_shape}, "
            f"got {len(wavelengths)}."
        )

    fig, axes = plt.subplots(*axes_shape, figsize=figsize)
    flat_axes = axes.flatten()
    nr_of_plots = len(flat_axes)
    max_spectral_values = max(
        len(result_map[dataset_name].coords["spectral"]) for dataset_name in result_map
    )
    if nr_of_plots > max_spectral_values:
        warn(
            PlotDuplicationWarning(
                f"The number of plots ({nr_of_plots}) exceeds the maximum number of "
                f"spectral data points ({max_spectral_values}), "
                "which will lead in duplicated plots."
            ),
            stacklevel=2,
        )
    for wavelength, axis in zip(wavelengths, flat_axes, strict=True):
        plot_data_and_fits(
            result=result_map,
            wavelength=wavelength,
            axis=axis,
            center_λ=center_λ,
            main_irf_nr=main_irf_nr,
            linlog=linlog,
            linthresh=linthresh,
            divide_by_scale=divide_by_scale,
            per_axis_legend=per_axis_legend,
            y_label=y_label,
            cycler=cycler,
            show_zero_line=show_zero_line,
        )
    if per_axis_legend is False:
        add_unique_figure_legend(fig, axes)
    fig.suptitle(title, fontsize=28)
    fig.tight_layout()
    return fig, axes