"""Module ``autoarray.plot.multi_plotters``: combine several plotters into one subplot figure or one overlaid 1D figure."""

from typing import Optional, Tuple

from autoarray.plot.wrap.base.ticks import YTicks
from autoarray.plot.wrap.base.ticks import XTicks


class MultiFigurePlotter:
    """
    Draw the figures of several plotters as the panels of a single subplot figure.

    Parameters
    ----------
    plotter_list
        The plotters to draw; plotter ``i`` is drawn into subplot panel ``i + 1``. The first
        plotter in the list opens, outputs and closes the subplot figure.
    subplot_shape
        The shape of the subplot grid (a tuple of two ints), assigned to every plotter's
        ``MatPlot`` object and passed when the subplot is opened.
    subplot_title
        Optional title passed to ``open_subplot_figure`` when the subplot is opened.
    """

    def __init__(
        self,
        plotter_list,
        subplot_shape: Tuple[int, int] = None,
        subplot_title: Optional[str] = None,
    ):
        self.plotter_list = plotter_list
        self.subplot_shape = subplot_shape
        self.subplot_title = subplot_title

    def _set_mat_plot_for_subplot(
        self, mat_plot, subplot_index: int, number_subplots: int
    ):
        """Configure ``mat_plot`` to draw into panel ``subplot_index`` of this subplot."""
        mat_plot.set_for_subplot(is_for_subplot=True)
        mat_plot.number_subplots = number_subplots
        mat_plot.subplot_shape = self.subplot_shape
        mat_plot.subplot_index = subplot_index

    def subplot_of_figure(self, func_name, figure_name, filename_suffix="", **kwargs):
        """
        Draw one figure of every plotter into its own subplot panel and output the subplot.

        Parameters
        ----------
        func_name
            Name of the method called on every plotter to draw its figure.
        figure_name
            Keyword passed as ``True`` to that method to select the figure. If ``None`` the
            method is called with ``kwargs`` only and the output file is named from
            ``filename_suffix`` alone.
        filename_suffix
            Appended to the output filename ``subplot_<figure_name>``.
        kwargs
            Extra keyword arguments forwarded to every call; they override the
            ``figure_name`` flag on key collision.
        """
        number_subplots = len(self.plotter_list)
        lead_plotter = self.plotter_list[0]

        open_kwargs = {
            "number_subplots": number_subplots,
            "subplot_shape": self.subplot_shape,
        }
        # Only forward a title when one was given, so untitled subplots open exactly as before.
        if self.subplot_title is not None:
            open_kwargs["subplot_title"] = self.subplot_title
        lead_plotter.open_subplot_figure(**open_kwargs)

        for i, plotter in enumerate(self.plotter_list):
            # Prefer the 2D MatPlot; any AttributeError while configuring it (for example
            # ``mat_plot_2d`` being ``None``) falls back to configuring the 1D MatPlot.
            try:
                self._set_mat_plot_for_subplot(
                    plotter.mat_plot_2d, i + 1, number_subplots
                )
            except AttributeError:
                self._set_mat_plot_for_subplot(
                    plotter.mat_plot_1d, i + 1, number_subplots
                )

            func = getattr(plotter, func_name)
            if figure_name is None:
                func(**kwargs)
            else:
                func(**{figure_name: True, **kwargs})

        # Without a figure name, name the file from the suffix alone (as
        # subplot_of_multi_yx_1d does) instead of writing a literal "subplot_None".
        if figure_name is None:
            auto_filename = f"subplot_{filename_suffix}"
        else:
            auto_filename = f"subplot_{figure_name}{filename_suffix}"

        # NOTE(review): if the lead plotter has both a 1D and a 2D MatPlot, both outputs are
        # invoked with the same filename — confirm this is intended.
        if lead_plotter.mat_plot_1d is not None:
            lead_plotter.mat_plot_1d.output.subplot_to_figure(
                auto_filename=auto_filename
            )
        if lead_plotter.mat_plot_2d is not None:
            lead_plotter.mat_plot_2d.output.subplot_to_figure(
                auto_filename=auto_filename
            )

        lead_plotter.close_subplot_figure()

    def subplot_of_multi_yx_1d(self, filename_suffix="", **kwargs):
        """
        Draw every multi-plotter's overlaid 1D figure into its own subplot panel.

        Each entry of ``plotter_list`` is expected to expose its own ``plotter_list`` of
        plotters and a ``figure_1d`` method. The first inner plotter of the first entry opens,
        outputs and closes the subplot figure.

        Parameters
        ----------
        filename_suffix
            Appended to the output filename ``subplot_``.
        kwargs
            Extra keyword arguments forwarded to every ``figure_1d`` call; they override the
            defaults ``func_name="figure_1d"``, ``figure_name=None``, ``is_for_subplot=True``.
        """
        number_subplots = len(self.plotter_list)
        lead_plotter = self.plotter_list[0].plotter_list[0]

        lead_plotter.open_subplot_figure(
            number_subplots=number_subplots,
            subplot_shape=self.subplot_shape,
            subplot_title=self.subplot_title,
        )

        for i, multi_plotter in enumerate(self.plotter_list):
            # All plotters overlaid by one multi-plotter share the same subplot panel.
            for plotter in multi_plotter.plotter_list:
                self._set_mat_plot_for_subplot(
                    plotter.mat_plot_1d, i + 1, number_subplots
                )

            multi_plotter.figure_1d(
                **{
                    "func_name": "figure_1d",
                    "figure_name": None,
                    "is_for_subplot": True,
                    **kwargs,
                }
            )

        lead_plotter.mat_plot_1d.output.subplot_to_figure(
            auto_filename=f"subplot_{filename_suffix}"
        )
        lead_plotter.close_subplot_figure()
[docs]class MultiYX1DPlotter: def __init__( self, plotter_list, color_list=None, legend_labels=None, y_manual_min_max_value=None, x_manual_min_max_value=None, ): self.plotter_list = plotter_list if color_list is None: color_list = 10 * ["k", "r", "b", "g", "c", "m", "y"] self.color_list = color_list self.legend_labels = legend_labels self.y_manual_min_max_value = y_manual_min_max_value self.x_manual_min_max_value = x_manual_min_max_value def figure_1d(self, func_name, figure_name, is_for_subplot=False, **kwargs): if not is_for_subplot: self.plotter_list[0].mat_plot_1d.figure.open() for i, plotter in enumerate(self.plotter_list): plotter.set_mat_plot_1d_for_multi_plot( is_for_multi_plot=True, color=self.color_list[i], yticks=self.yticks, xticks=self.xticks, ) if self.legend_labels is not None: plotter.mat_plot_1d.yx_plot.label = self.legend_labels[i] func = getattr(plotter, func_name) if figure_name is None: func(**{**{}, **kwargs}) else: func(**{**{figure_name: True}, **kwargs}) plotter.set_mat_plot_1d_for_multi_plot(is_for_multi_plot=False, color=None) if not is_for_subplot: self.plotter_list[0].mat_plot_1d.output.subplot_to_figure( auto_filename=f"multi_{figure_name}" ) self.plotter_list[0].mat_plot_1d.figure.close() @property def yticks(self): # TODO: Need to make this work for all plotters, rather than just y x, for example # TODO : GalaxyPlotters where y and x are computed inside the function called via # TODO : func(**{**{figure_name: True}, **kwargs}) if self.y_manual_min_max_value is not None: return YTicks(manual_min_max_value=self.y_manual_min_max_value) try: min_value = min([min(plotter.y) for plotter in self.plotter_list]) max_value = max([max(plotter.y) for plotter in self.plotter_list]) except AttributeError: return return YTicks(manual_min_max_value=(min_value, max_value)) @property def xticks(self): if self.x_manual_min_max_value is not None: return XTicks(manual_min_max_value=self.x_manual_min_max_value) try: min_value = min([min(plotter.x) for 
plotter in self.plotter_list]) max_value = max([max(plotter.x) for plotter in self.plotter_list]) except AttributeError: return return XTicks(manual_min_max_value=(min_value, max_value))