# Source code for autoarray.plot.wrap.one_d.yx_plot

import matplotlib.pyplot as plt
import numpy as np
from typing import Union

from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D
from autoarray.structures.arrays.uniform_1d import Array1D

from autoarray import exc

class YXPlot(AbstractMatWrap1D):
    def __init__(self, plot_axis_type=None, label=None, **kwargs):
        """
        Plots 1D data structures as a y vs x figure.

        This object wraps the following Matplotlib methods:

        - plt.plot
        - plt.semilogy
        - plt.loglog
        - plt.scatter
        - plt.errorbar

        Parameters
        ----------
        plot_axis_type
            Default method used to make the plot (one of "linear", "symlog",
            "semilogy", "loglog", "scatter", "errorbar", "errorbar_logy"). It is
            used by `plot_y_vs_x` when no `plot_axis_type` is passed to it.
        label
            If not None, this label is used for every plot made by this object and
            overrides any `label` passed to `plot_y_vs_x`.
        kwargs
            Forwarded to `AbstractMatWrap1D`. Presumably these populate
            `self.config_dict`, the matplotlib keyword arguments used when
            plotting -- TODO confirm against the base class.
        """
        super().__init__(**kwargs)

        self.plot_axis_type = plot_axis_type
        self.label = label

    def plot_y_vs_x(
        self,
        y: Union[np.ndarray, Array1D],
        x: Union[np.ndarray, Array1D],
        label: Union[str, None] = None,
        plot_axis_type=None,
        y_errors=None,
        x_errors=None,
        y_extra=None,
        y_extra_2=None,
        ls_errorbar="",
    ):
        """
        Plots 1D y-data against 1D x-data using the matplotlib method `plt.plot`,
        `plt.semilogy`, `plt.loglog`, `plt.scatter` or `plt.errorbar`.

        Parameters
        ----------
        y
            The ydata that is plotted.
        x
            The xdata that is plotted.
        label
            Optionally include a label on the plot for a `Legend` to display. This
            is ignored if the object was constructed with a non-None `label`.
        plot_axis_type
            The method used to make the plot that defines the scale of the axes:
            {"linear", "symlog", "semilogy", "loglog", "scatter", "errorbar",
            "errorbar_logy"}. If None, the `plot_axis_type` given to the
            constructor is used.
        y_errors
            Errors on `y`, passed as `yerr` to `plt.errorbar` (errorbar types only).
        x_errors
            Errors on `x`, passed as `xerr` to `plt.errorbar` (errorbar types only).
        y_extra
            Optional extra y-data plotted against `x` as a line. If it is a list,
            each entry is plotted with no explicit colour. Otherwise it is plotted
            in red.
        y_extra_2
            Optional second extra y-data plotted against `x` as a red line.
        ls_errorbar
            Currently unused. It is kept so existing callers keep working.

        Raises
        ------
        exc.PlottingException
            If the resolved `plot_axis_type` is not one of the supported strings.
        """
        # An instance-level label takes precedence over the per-call label.
        if self.label is not None:
            label = self.label

        # Fall back to the default supplied at construction time.
        if plot_axis_type is None:
            plot_axis_type = self.plot_axis_type

        if plot_axis_type in ("linear", "symlog"):
            # "symlog" is drawn with plt.plot here. Any symlog axis scaling is
            # presumably applied elsewhere -- TODO confirm.
            plt.plot(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "semilogy":
            plt.semilogy(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "loglog":
            plt.loglog(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "scatter":
            plt.scatter(x, y, label=label, **self.config_dict)
        elif plot_axis_type in ("errorbar", "errorbar_logy"):
            # fmt="o" draws markers only. The label is passed so errorbar plots
            # appear in a legend like every other plot type.
            plt.errorbar(
                x,
                y,
                yerr=y_errors,
                xerr=x_errors,
                fmt="o",
                label=label,
                **self.config_dict
            )
            if plot_axis_type == "errorbar_logy":
                plt.yscale("log")
        else:
            raise exc.PlottingException(
                f"The plot_axis_type supplied to the plotter ({plot_axis_type!r}) is "
                "not a valid string (must be one of linear, symlog, semilogy, "
                "loglog, scatter, errorbar, errorbar_logy)"
            )

        if y_extra is not None:
            if isinstance(y_extra, list):
                # No explicit colour is set for list entries.
                for y_extra_ in y_extra:
                    plt.plot(x, y_extra_)
            else:
                plt.plot(x, y_extra, c="r")

        if y_extra_2 is not None:
            plt.plot(x, y_extra_2, c="r")