# Source code for autoarray.plot.wrap.base.ticks

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import numpy as np
from typing import List, Tuple, Optional

from autoconf import conf

from autoarray.plot.wrap.base.abstract import AbstractMatWrap
from autoarray.plot.wrap.base.units import Units


class TickMaker:
    def __init__(
        self,
        min_value: float,
        max_value: float,
        factor: float,
        number_of_ticks: int,
        units,
    ):
        """
        Computes the numerical positions of the ticks drawn along one axis of a figure.

        Parameters
        ----------
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        factor
            The fraction of the half-range, measured outwards from the centre, at which the two outermost
            linear ticks are placed (a value of 1.0 puts them exactly on `min_value` and `max_value`).
        number_of_ticks
            The number of linearly spaced ticks that are created.
        units
            An object with a `use_scaled` attribute, which is only read by `tick_values_integers`.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.factor = factor
        self.number_of_ticks = number_of_ticks
        self.units = units

    @property
    def centre(self):
        """
        The mid-point of the range between `min_value` and `max_value`.
        """
        return self.max_value - ((self.max_value - self.min_value) / 2.0)

    @property
    def tick_values_linear(self):
        """
        The `number_of_ticks` evenly spaced tick values, which are centred on `centre` and extend `factor` times
        the half-range either side of it.

        The values are ordered from the upper end of the range down to the lower end.
        """
        centre = self.centre

        upper = centre - (centre - self.max_value) * self.factor
        lower = centre + (self.min_value - centre) * self.factor

        return np.linspace(upper, lower, self.number_of_ticks)

    @property
    def tick_values_log10(self):
        """
        One tick per decade, from the power of ten at or below `min_value` up to the power of ten at or
        above `max_value`.

        A `min_value` below 0.001 is replaced by 0.001 so that zero or negative lower limits do not reach the
        logarithm.
        """
        lowest = 0.001 if self.min_value < 0.001 else self.min_value

        # The floor / ceil results are exact integer-valued floats. They are used directly (instead of being
        # converted to 10**k and back through log10) so that the tick count below cannot be truncated by a
        # floating point round-trip error.
        exponent_min = np.floor(np.log10(lowest))
        exponent_max = np.ceil(np.log10(self.max_value))

        number = int(abs(exponent_max - exponent_min)) + 1

        return np.logspace(exponent_min, exponent_max, number)

    @property
    def tick_values_integers(self):
        """
        The integers 0, 1, 2, ... up to (but excluding) the truncated difference `max_value - min_value`.

        Note the values start at zero irrespective of `min_value`.
        """
        ticks = np.arange(int(self.max_value - self.min_value))

        if not self.units.use_scaled:
            ticks = ticks.astype("int")

        return ticks


class LabelMaker:
    def __init__(
        self,
        tick_values,
        min_value: float,
        max_value: float,
        units,
        pixels: Optional[int] = None,
        round_sf: int = 2,
        yunit=None,
        xunit=None,
        manual_suffix=None,
    ):
        """
        Converts the numerical tick values of an axis into the text labels that are displayed next to them.

        Parameters
        ----------
        tick_values
            The numerical values of the ticks that are turned into labels.
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        units
            An object with `ticks_convert_factor`, `ticks_label`, `use_scaled` and `use_raw` attributes. It may
            be `None` for the suffix and the log10 labels, but `labels_linear` requires it.
        pixels
            The number of pixels spanned by the axis, which is required by `labels_linear_pixels`.
        round_sf
            The number of significant figures tick values are rounded to.
        yunit
            A unit string used as the label suffix, which also disables the pixel labels of `labels_linear`.
        xunit
            A unit string used as the label suffix if `yunit` is not given.
        manual_suffix
            A suffix which takes priority over every other source of the suffix.
        """
        self.tick_values = tick_values
        self.min_value = min_value
        self.max_value = max_value
        self.units = units
        self.pixels = pixels

        # `suffix` explicitly supports `units=None`, so construction must not dereference it blindly.
        convert_factor = units.ticks_convert_factor if units is not None else None
        self.convert_factor = convert_factor or 1.0

        self.yunit = yunit
        self.xunit = xunit
        self.round_sf = round_sf
        self.manual_suffix = manual_suffix

    @property
    def suffix(self) -> Optional[str]:
        """
        Returns the suffix appended to every label, using the first of the following which is not `None`:

        1) `manual_suffix`.
        2) `yunit`.
        3) `xunit`.
        4) `units.ticks_label`.
        5) The `scaled_symbol` (if `units.use_scaled` is true) or `unscaled_symbol` entry of the `units` section
           of the `visualize/general` config. If `units` is `None` the `unscaled_symbol` is used.
        """

        if self.manual_suffix is not None:
            return self.manual_suffix

        if self.yunit is not None:
            return self.yunit

        if self.xunit is not None:
            return self.xunit

        # The `None` check must come before any attribute of `units` is read, otherwise the config fallback
        # below can never be reached for `units=None`.
        if self.units is not None and self.units.ticks_label is not None:
            return self.units.ticks_label

        units_conf = conf.instance["visualize"]["general"]["units"]

        if self.units is not None and self.units.use_scaled:
            return units_conf["scaled_symbol"]

        return units_conf["unscaled_symbol"]

    @property
    def span(self):
        """
        The size of the range between `min_value` and `max_value`.
        """
        return self.max_value - self.min_value

    @property
    def tick_values_rounded(self):
        """
        The tick values multiplied by `convert_factor` and rounded to `round_sf` significant figures.
        """
        values = np.asarray(self.tick_values) * self.convert_factor

        # Zeros and non-finite entries have no order of magnitude, so they are swapped for a placeholder which
        # gives them a rounding multiplier of 1 and avoids evaluating log10 on them.
        values_positive = np.where(
            np.isfinite(values) & (values != 0),
            np.abs(values),
            10 ** (self.round_sf - 1),
        )
        mags = 10 ** (self.round_sf - 1 - np.floor(np.log10(values_positive)))
        return np.round(values * mags) / mags

    @property
    def labels_linear(self):
        """
        The labels of linearly spaced ticks.

        Rounded tick values are used if `units.use_raw` is true, if `units.use_scaled` is true or if a `yunit`
        is given. Otherwise the labels are the pixel positions given by `labels_linear_pixels`.
        """
        if self.units.use_raw:
            return self.with_appended_suffix(self.tick_values_rounded)

        if not self.units.use_scaled and self.yunit is None:
            return self.labels_linear_pixels

        return self.with_appended_suffix(self.tick_values_rounded)

    @property
    def labels_linear_pixels(self):
        """
        The labels of linearly spaced ticks expressed as integer pixel positions, where `min_value` maps to
        pixel 0 and `max_value` maps to pixel `pixels`.

        If the range has zero size the tick values themselves are truncated to integers instead, because the
        mapping would divide by zero.
        """
        if self.max_value == self.min_value:
            labels = [f"{int(label)}" for label in self.tick_values]
            return self.with_appended_suffix(labels)

        span = self.span

        labels = [
            f"{int(((tick - self.min_value) / span) * self.pixels)}"
            for tick in self.tick_values
        ]

        return self.with_appended_suffix(labels)

    @staticmethod
    def _log10_label(value) -> str:
        """
        Returns a single tick value as mathtext, rounded to one significant figure (e.g. 0.001 is `$10^{-3}$`
        and 5.0 is `$5 \\times 10^{0}$`).
        """
        text = "{:.0e}".format(value)
        mantissa, separator, exponent = text.partition("e")

        if not separator:
            # Non-finite values format as "inf" / "nan", which have no exponent to typeset.
            return text

        # int() drops the explicit "+" sign and the zero padding of the exponent ("+02" -> 2, "-03" -> -3).
        exponent = int(exponent)

        if mantissa == "1":
            return f"$10^{{{exponent}}}$"

        # Ticks which are not exact powers of ten (e.g. manually input values) keep their leading digit, so
        # the label is always a balanced mathtext expression.
        return f"${mantissa} \\times 10^{{{exponent}}}$"

    @property
    def labels_log10(self):
        """
        The labels of ticks on a log10 axis, written as mathtext powers of ten with the suffix appended.
        """
        labels = [self._log10_label(value) for value in self.tick_values]

        return self.with_appended_suffix(labels)

    def with_appended_suffix(self, labels):
        """
        The labels used for the y and x ticks can be appended with a suffix.

        For example, if the labels were [-1.0, 0.0, 1.0] and the suffix is ", the labels with the suffix appended
        is [-1", 0", 1"] (a trailing ".0" is removed when every label ends with it).

        Parameters
        ----------
        labels
            The y and x labels which are appended with the suffix.
        """

        labels = [str(label) for label in labels]

        # The trailing ".0" is only stripped if it can be stripped from every label, so that a mixture of
        # integer and non-integer values keeps a consistent format.
        if all(label.endswith(".0") for label in labels):
            labels = [label[:-2] for label in labels]

        # Resolve the suffix once, as the property may perform a config lookup.
        suffix = self.suffix

        return [f"{label}{suffix}" for label in labels]


class AbstractTicks(AbstractMatWrap):
    def __init__(
        self,
        manual_factor: Optional[float] = None,
        manual_values: Optional[List[float]] = None,
        manual_min_max_value: Optional[Tuple[float, float]] = None,
        manual_units: Optional[str] = None,
        manual_suffix: Optional[str] = None,
        **kwargs,
    ):
        """
        The settings used to customize a figure's y and x ticks using the `YTicks` and `XTicks` objects.

        This object wraps the following Matplotlib methods:

        - plt.yticks: https://matplotlib.org/3.3.1/api/_as_gen/matplotlib.pyplot.yticks.html
        - plt.xticks: https://matplotlib.org/3.3.1/api/_as_gen/matplotlib.pyplot.xticks.html

        Parameters
        ----------
        manual_factor
            Manually override the `extent_factor` config value, which is passed to the `TickMaker` as the
            `factor` controlling how far from the centre of the range the outermost linear ticks are placed.
        manual_values
            Manually override the tick values, so that the ticks are the input list of floats.
        manual_min_max_value
            A (min, max) tuple which the `YTicks` and `XTicks` objects use in place of the minimum and maximum
            values passed to their `set` method.
        manual_units
            If input, the tick values themselves are returned as the labels by `labels_from`, and the `YTicks`
            and `XTicks` objects use this string as the format string of a Matplotlib `FormatStrFormatter`.
        manual_suffix
            A suffix applied to every tick label (e.g. for the suffix `kpc` 0.0 becomes 0.0kpc).
        """
        super().__init__(**kwargs)

        self.manual_factor = manual_factor
        self.manual_values = manual_values
        self.manual_min_max_value = manual_min_max_value
        self.manual_units = manual_units
        self.manual_suffix = manual_suffix

    def _manual_config_value(self, key: str):
        """
        Returns an entry of the `manual` section of this object's `visualize` config.

        The section is looked up via `self.config_folder` (not defined in this class, presumably provided by
        `AbstractMatWrap`) and the name of the class.
        """
        return conf.instance["visualize"][self.config_folder][self.__class__.__name__][
            "manual"
        ][key]

    def factor_from(self, suffix):
        """
        Returns the factor passed to the `TickMaker`, which is `manual_factor` if it was input and otherwise
        the `extent_factor{suffix}` config value.

        Parameters
        ----------
        suffix
            The string appended to the config key, which is `_1d` or `_2d` when called by `tick_maker_from`.
        """
        if self.manual_factor is not None:
            return self.manual_factor
        return self._manual_config_value(f"extent_factor{suffix}")

    def number_of_ticks_from(self, suffix):
        """
        Returns the number of ticks passed to the `TickMaker`, given by the `number_of_ticks{suffix}` config value.

        Parameters
        ----------
        suffix
            The string appended to the config key, which is `_1d` or `_2d` when called by `tick_maker_from`.
        """
        return self._manual_config_value(f"number_of_ticks{suffix}")

    def tick_maker_from(
        self, min_value: float, max_value: float, units, is_for_1d_plot: bool
    ):
        """
        Returns the `TickMaker` for a range of values, configured with the 1D or 2D factor and number of ticks.

        Parameters
        ----------
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        units
            The units of the figure, which are passed on to the `TickMaker`.
        is_for_1d_plot
            If `True` the `_1d` config values are used, otherwise the `_2d` config values are used.
        """
        suffix = "_1d" if is_for_1d_plot else "_2d"

        return TickMaker(
            min_value=min_value,
            max_value=max_value,
            factor=self.factor_from(suffix=suffix),
            units=units,
            number_of_ticks=self.number_of_ticks_from(suffix=suffix),
        )

    def ticks_from(
        self,
        min_value: float,
        max_value: float,
        units: Units,
        is_log10: bool = False,
        is_for_1d_plot: bool = False,
    ):
        """
        Returns the tick values, which are the non-empty `manual_values` if they were input, the log10 tick
        values of the `TickMaker` if `is_log10` is `True` and its linear tick values otherwise.

        Parameters
        ----------
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        units
            The units of the figure.
        is_log10
            If `True` the ticks are spaced by decades instead of linearly.
        is_for_1d_plot
            If `True` the config values of 1D plots are used to make the ticks.
        """
        # The manual values are checked first so that no config is read when it is not needed. The explicit
        # length check (rather than truthiness) also allows them to be a numpy array.
        if self.manual_values is not None and len(self.manual_values) > 0:
            return self.manual_values

        tick_maker = self.tick_maker_from(
            min_value=min_value,
            max_value=max_value,
            units=units,
            is_for_1d_plot=is_for_1d_plot,
        )

        if is_log10:
            return tick_maker.tick_values_log10
        return tick_maker.tick_values_linear

    def labels_from(
        self,
        ticks,
        min_value: float,
        max_value: float,
        units,
        yunit,
        xunit,
        pixels: Optional[int] = None,
        is_log10: bool = False,
    ):
        """
        Returns the labels of the ticks, which are the ticks themselves if `manual_units` was input, the log10
        labels of the `LabelMaker` if `is_log10` is `True` and its linear labels otherwise.

        Parameters
        ----------
        ticks
            The tick values the labels are made for.
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        units
            The units of the figure.
        yunit
            A unit string passed to the `LabelMaker` as its `yunit`.
        xunit
            A unit string passed to the `LabelMaker` as its `xunit`.
        pixels
            The number of pixels spanned by the axis, passed to the `LabelMaker`.
        is_log10
            If `True` the labels are written as powers of ten.
        """
        # With manual units the labels are produced by a formatter from the tick values, so no `LabelMaker`
        # needs to be built.
        if self.manual_units:
            return ticks

        label_maker = LabelMaker(
            tick_values=ticks,
            min_value=min_value,
            max_value=max_value,
            units=units,
            pixels=pixels,
            yunit=yunit,
            xunit=xunit,
            manual_suffix=self.manual_suffix,
        )

        if is_log10:
            return label_maker.labels_log10
        return label_maker.labels_linear

    def ticks_and_labels_from(
        self,
        min_value,
        max_value,
        units,
        pixels: Optional[int] = None,
        use_integers: bool = False,
        yunit=None,
        xunit=None,
        is_log10: bool = False,
        is_for_1d_plot: bool = False,
    ):
        """
        Returns the tick values and their labels as a tuple `(ticks, labels)`.

        Parameters
        ----------
        min_value
            The lower edge of the range of values the ticks describe.
        max_value
            The upper edge of the range of values the ticks describe.
        units
            The units of the figure.
        pixels
            The number of pixels spanned by the axis.
        use_integers
            If `True` the ticks and the labels are both the integers 0, 1, 2, ... up to (but excluding) the
            truncated difference `max_value - min_value`, and all other inputs are ignored.
        yunit
            A unit string passed to the `LabelMaker` as its `yunit`.
        xunit
            A unit string passed to the `LabelMaker` as its `xunit`.
        is_log10
            If `True` the ticks are spaced by decades and labelled as powers of ten.
        is_for_1d_plot
            If `True` the config values of 1D plots are used to make the ticks.
        """
        if use_integers:
            ticks = np.arange(int(max_value - min_value))
            return ticks, ticks

        ticks = self.ticks_from(
            min_value=min_value,
            max_value=max_value,
            units=units,
            is_log10=is_log10,
            is_for_1d_plot=is_for_1d_plot,
        )

        labels = self.labels_from(
            ticks=ticks,
            min_value=min_value,
            max_value=max_value,
            units=units,
            yunit=yunit,
            xunit=xunit,
            pixels=pixels,
            is_log10=is_log10,
        )
        return ticks, labels


class YTicks(AbstractTicks):
    def set(
        self,
        min_value: float,
        max_value: float,
        units: Units,
        pixels: Optional[int] = None,
        yunit=None,
        is_for_1d_plot: bool = False,
        is_log10: bool = False,
    ):
        """
        Set the y ticks of a figure using a range of values and input units.

        Parameters
        ----------
        min_value
            the minimum value of the yticks that figure is plotted using.
        max_value
            the maximum value of the yticks that figure is plotted using.
        units
            The units of the figure.
        pixels
            The number of pixels spanned by the y axis, which is used to make the labels.
        yunit
            A unit string which is used to make the labels.
        is_for_1d_plot
            If `True` the config values of 1D plots are used to make the ticks.
        is_log10
            If `True` the ticks are spaced by decades, labelled as powers of ten and the limits of the y axis
            are set.
        """
        if self.manual_min_max_value:
            min_value, max_value = (
                self.manual_min_max_value[0],
                self.manual_min_max_value[1],
            )

        ticks, labels = self.ticks_and_labels_from(
            min_value=min_value,
            max_value=max_value,
            units=units,
            pixels=pixels,
            yunit=yunit,
            is_log10=is_log10,
            is_for_1d_plot=is_for_1d_plot,
        )

        if is_log10:
            # `log10_min_value` is not defined in this class or in `AbstractTicks` -- presumably it is provided
            # by `AbstractMatWrap` (TODO confirm). It bounds the lower limit of the axis from below.
            plt.ylim(max(min_value, self.log10_min_value), max_value)

        if not is_for_1d_plot and not units.use_scaled:
            # The labels of unscaled 2D figures are applied in reverse order. A list is built (rather than
            # passing the one-shot `reversed` iterator) so that Matplotlib receives a real sequence.
            labels = list(reversed(labels))

        plt.yticks(ticks=ticks, labels=labels, **self.config_dict)

        if self.manual_units is not None:
            # The manual units are used as the format string of every tick label.
            plt.gca().yaxis.set_major_formatter(
                FormatStrFormatter(f"{self.manual_units}")
            )
class XTicks(AbstractTicks):
    def set(
        self,
        min_value: float,
        max_value: float,
        units: Units,
        pixels: Optional[int] = None,
        xunit=None,
        use_integers=False,
        is_for_1d_plot: bool = False,
        is_log10: bool = False,
    ):
        """
        Set the x ticks of a figure using a range of values and input units.

        Parameters
        ----------
        min_value
            the minimum value of the xticks that figure is plotted using.
        max_value
            the maximum value of the xticks that figure is plotted using.
        units
            The units of the figure.
        pixels
            The number of pixels spanned by the x axis, which is used to make the labels.
        xunit
            A unit string which is used to make the labels.
        use_integers
            If `True` the ticks and labels are the integers 0, 1, 2, ... up to (but excluding) the truncated
            difference `max_value - min_value`.
        is_for_1d_plot
            If `True` the config values of 1D plots are used to make the ticks.
        is_log10
            If `True` the ticks are spaced by decades and labelled as powers of ten.
        """
        if self.manual_min_max_value:
            min_value, max_value = (
                self.manual_min_max_value[0],
                self.manual_min_max_value[1],
            )

        # NOTE(review): `xunit` is forwarded through the `yunit` keyword. `LabelMaker.labels_linear` only
        # inspects `yunit` when it chooses between pixel labels and value labels, so this looks deliberate --
        # confirm before changing it to `xunit=xunit`.
        ticks, labels = self.ticks_and_labels_from(
            min_value=min_value,
            max_value=max_value,
            pixels=pixels,
            units=units,
            yunit=xunit,
            use_integers=use_integers,
            is_for_1d_plot=is_for_1d_plot,
            is_log10=is_log10,
        )

        plt.xticks(ticks=ticks, labels=labels, **self.config_dict)

        if self.manual_units is not None:
            # The manual units are used as the format string of every tick label.
            plt.gca().xaxis.set_major_formatter(
                FormatStrFormatter(f"{self.manual_units}")
            )