Source code for autoarray.plot.wrap.two_d.grid_scatter

import matplotlib.pyplot as plt
import numpy as np
import itertools
from typing import List, Union

from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D
from autoarray.structures.grids.uniform_2d import Grid2D
from autoarray.structures.grids.irregular_2d import Grid2DIrregular

from autoarray import exc


class GridScatter(AbstractMatWrap2D):
    """
    Scatters an input set of grid points, for example (y,x) coordinates or data structures representing 2D (y,x)
    coordinates like a `Grid2D` or `Grid2DIrregular`. List of (y,x) coordinates are plotted with varying colors.

    This object wraps the following Matplotlib methods:

    - plt.scatter: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.scatter.html

    There are a number of children of this method in the `mat_obj.py` module that plot specific sets of (y,x)
    points. Each of these objects uses their own config file and settings so that each has a unique appearance
    on every figure:

    - `OriginScatter`: plots the (y,x) coordinates of the origin of a data structure (e.g. as a black cross).
    - `MaskScatter`: plots a mask over an image, using the `Mask2d` object's (y,x) `edge_sub_1` property.
    - `BorderScatter`: plots a border over an image, using the `Mask2d` object's (y,x) `border_sub_1` property.
    - `PositionsScatter`: plots the (y,x) coordinates that are input in a plotter via the `positions` input.
    - `IndexScatter`: plots specific (y,x) coordinates of a grid (or grids) via their 1d or 2d indexes.
    - `MeshGridScatter`: plots the grid of a `Mesh` object (see `autoarray.inversion`).

    Parameters
    ----------
    colors : [str]
        The color or list of colors that the grid is plotted using. For plotting indexes or a grid list, a
        list of colors can be specified which the plot cycles through.
    """

    def _colors_and_kwargs(self):
        """
        Split the scatter configuration into the colors to cycle through and the remaining keyword arguments.

        The configuration is copied before the `"c"` entry is removed, so `self.config_dict` is never mutated.

        Returns
        -------
        tuple
            `(colors, kwargs)` where `colors` is an iterable of colors and `kwargs` is a new dict holding every
            entry of `self.config_dict` except `"c"`.
        """
        kwargs = dict(self.config_dict)
        colors = kwargs.pop("c", None)

        # A single color given as a string (e.g. "black") must be treated as one color. Iterating over it
        # directly would cycle through its characters. A missing color becomes `None`, which leaves the
        # choice of color to matplotlib.
        if colors is None or isinstance(colors, str):
            colors = [colors]

        return colors, kwargs

    @staticmethod
    def _all_of_type(values, types) -> bool:
        """
        Return `True` if every entry of `values` is an instance of `types` (also `True` when `values` is empty).
        """
        return all(isinstance(value, types) for value in values)

    def scatter_grid(self, grid: Union[np.ndarray, Grid2D]):
        """
        Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`.

        If multiple colors are configured only the first one is used. If the input cannot be indexed as a
        single grid (an `IndexError` or `TypeError` is raised), it is passed to `scatter_grid_list` instead.

        Parameters
        ----------
        grid : Grid2D
            The grid of (y,x) coordinates that is plotted.
        """
        # Work on a copy so that collapsing the colors below does not alter `self.config_dict`.
        config_dict = dict(self.config_dict)

        colors = config_dict.get("c")

        # A string is a single color and is passed through unchanged; only a sequence of colors is collapsed.
        if colors is not None and not isinstance(colors, str) and len(colors) > 1:
            config_dict["c"] = colors[0]

        try:
            plt.scatter(y=grid[:, 0], x=grid[:, 1], **config_dict)
        except (IndexError, TypeError):
            return self.scatter_grid_list(grid_list=grid)

    def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]):
        """
        Plot an input list of grids of (y,x) coordinates using the matplotlib method `plt.scatter`.

        Every grid in the list is plotted with the next color of the configured colors, cycling back to the
        first color once they are exhausted, so that the different grids are visible in the plot.

        Plotting is best-effort: if indexing a grid raises an `IndexError`, plotting stops and the method
        returns without raising.

        Parameters
        ----------
        grid_list
            The list of grids of (y,x) coordinates that are plotted.
        """
        if len(grid_list) == 0:
            return

        colors, config_dict = self._colors_and_kwargs()
        color = itertools.cycle(colors)

        try:
            for grid in grid_list:
                plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict)
        except IndexError:
            return None

    def scatter_grid_colored(
        self, grid: Union[np.ndarray, Grid2D], color_array: np.ndarray, cmap: str
    ):
        """
        Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`.

        The method colors the scattered grid according to an input ndarray of color values, using an input
        colormap. The configured color (`"c"`) is ignored, because `color_array` takes its place.

        Parameters
        ----------
        grid : Grid2D
            The grid of (y,x) coordinates that is plotted.
        color_array : ndarray
            The array of values passed to `plt.scatter` as `c`, which are mapped to colors using `cmap`.
        cmap
            The Matplotlib colormap used for the grid point coloring.
        """
        _, config_dict = self._colors_and_kwargs()

        plt.scatter(y=grid[:, 0], x=grid[:, 1], c=color_array, cmap=cmap, **config_dict)

    def scatter_grid_indexes(
        self, grid: Union[np.ndarray, Grid2D], indexes: np.ndarray
    ):
        """
        Plot specific points of an input grid of (y,x) coordinates, which are specified according to the 1D or
        2D indexes of the `Grid2D`.

        This method allows us to color in points on grids that map between one another.

        `indexes` is interpreted as one or more groups of indexes, where every group is plotted in the next
        color of the configured colors. A list which contains no nested lists is treated as a single group.
        A group may contain:

        - 1D indexes (all ints or all floats), which index the first dimension of `grid`.
        - 2D indexes (all tuples or all lists) of the form (y, x), which index `grid.native`.

        Parameters
        ----------
        grid : Grid2D
            The grid of (y,x) coordinates that is plotted.
        indexes
            The indexes of the grid that are colored in when plotted. An ndarray is converted to a (nested)
            list and handled identically to the equivalent list.

        Raises
        ------
        exc.PlottingException
            If `grid` is not an ndarray or `Grid2D`, if `grid` is not 2D, or if a group of indexes is not made
            up entirely of ints, floats, tuples or lists.
        """
        if not isinstance(grid, (np.ndarray, Grid2D)):
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not a ndarray and thus its "
                "1D indexes cannot be marked and plotted."
            )

        if len(grid.shape) != 2:
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not 2D (e.g. a flattened 1D "
                "grid) and thus its 1D indexes cannot be marked."
            )

        # Converting an ndarray gives built-in ints / floats in (nested) lists, so it follows the same path
        # as list input below.
        if isinstance(indexes, np.ndarray):
            indexes = indexes.tolist()

        if isinstance(indexes, list):
            if not any(isinstance(i, list) for i in indexes):
                indexes = [indexes]

        colors, config_dict = self._colors_and_kwargs()
        color = itertools.cycle(colors)

        for index_list in indexes:
            # NumPy scalar types are accepted alongside the built-in types, so that indexes taken from
            # NumPy results (e.g. a list of `np.int64`) are supported.
            if self._all_of_type(index_list, (float, np.floating)) or self._all_of_type(
                index_list, (int, np.integer)
            ):
                plt.scatter(
                    y=grid[index_list, 0],
                    x=grid[index_list, 1],
                    color=next(color),
                    **config_dict,
                )

            elif self._all_of_type(index_list, tuple) or self._all_of_type(
                index_list, list
            ):
                # Split the (y, x) index pairs into separate lists of y and x indexes.
                ys, xs = map(list, zip(*index_list))

                # NOTE(review): this branch requires `grid` to expose a `native` attribute - a plain ndarray
                # presumably does not, confirm against callers.
                plt.scatter(
                    y=grid.native[ys, xs, 0],
                    x=grid.native[ys, xs, 1],
                    color=next(color),
                    **config_dict,
                )

            else:
                raise exc.PlottingException(
                    "The indexes input into the grid_scatter_index method do not conform to a "
                    "useable type"
                )