Source code for autogalaxy.util.plot_utils

import logging
import os
from pathlib import Path
import numpy as np

logger = logging.getLogger(__name__)


def _to_lines(*items):
    """Merge several line sources into one flat list of ``(N, 2)`` arrays.

    Every positional argument is one of:

    * ``None`` - ignored;
    * a ``list`` - each element is treated as one line-like object;
    * anything else - treated as a single line-like object.

    A line-like object is converted with ``np.array``, using its ``.array``
    attribute when it has one.  Only non-empty 2-D results whose second
    dimension has length 2 are kept; anything else, including objects whose
    conversion raises, is dropped without an error.

    Parameters
    ----------
    *items
        Any number of line sources to merge.

    Returns
    -------
    list of np.ndarray or None
        The accepted arrays in the order they were encountered, or ``None``
        when no argument produced a valid array.
    """
    collected = []
    for source in items:
        if source is None:
            continue
        # A bare object is wrapped so lists and single objects share one
        # conversion path.
        candidates = source if isinstance(source, list) else [source]
        for candidate in candidates:
            try:
                raw = candidate.array if hasattr(candidate, "array") else candidate
                coords = np.array(raw)
                keep = coords.ndim == 2 and coords.shape[1] == 2 and len(coords) > 0
            except Exception:
                # Best-effort by design: unconvertible entries are skipped.
                continue
            if keep:
                collected.append(coords)
    return collected or None


def _to_positions(*items):
    """Merge several position sources into one flat list of ``(N, 2)`` arrays.

    Positions are handled exactly like lines, so this simply forwards every
    argument to :func:`_to_lines` and returns its result.

    Parameters
    ----------
    *items
        Any number of position sources to merge.

    Returns
    -------
    list of np.ndarray or None
        Whatever :func:`_to_lines` returns for the same arguments: a flat
        list of ``(N, 2)`` arrays, or ``None`` if nothing valid was found.
    """
    merged = _to_lines(*items)
    return merged


def _save_subplot(fig, output_path, output_filename, output_format=None,
                  dpi=300):
    """Save a subplot figure to disk, or show it, then close it.

    For FITS output use the dedicated ``fits_*`` functions instead.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to save or show.
    output_path : str, Path or None
        Directory to write into; created if missing.  A falsy value makes
        the function call ``plt.show()`` instead of saving.
    output_filename : str
        File name stem; ``.<format>`` is appended.
    output_format : str, list, tuple or None
        Normalised with :func:`_resolve_format`.  The special value
        ``"show"`` calls ``plt.show()`` instead of saving.
    dpi : int
        Resolution passed to ``fig.savefig``.
    """
    import matplotlib.pyplot as plt
    from autoarray.plot.utils import _output_mode_save, _FAST_PLOTS

    # When the autoarray hook returns a truthy value nothing else is done
    # here; note the figure is not closed on this path.
    # NOTE(review): presumably the hook has already handled the figure - confirm.
    if _output_mode_save(fig, output_filename):
        return

    if _FAST_PLOTS:
        plt.close(fig)
        return

    # Use the shared helper rather than duplicating the list/str/default logic.
    fmt = _resolve_format(output_format)

    if fmt == "show" or not output_path:
        plt.show()
    else:
        os.makedirs(str(output_path), exist_ok=True)
        fpath = Path(output_path) / f"{output_filename}.{fmt}"
        fig.savefig(fpath, dpi=dpi, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)


def _resolve_colormap(colormap):
    """Return *colormap*, substituting the autoarray default when unspecified.

    ``"default"`` and ``None`` both resolve to the value returned by
    ``autoarray.plot.utils._default_colormap()``; any other value is returned
    unchanged.
    """
    if colormap not in ("default", None):
        return colormap

    # Imported lazily so explicit colormaps never touch autoarray.
    from autoarray.plot.utils import _default_colormap
    return _default_colormap()


def _resolve_format(output_format):
    """Normalise *output_format* to a single format value.

    Parameters
    ----------
    output_format : str, list, tuple or None
        A format string, or a sequence of formats of which only the first
        entry is used.

    Returns
    -------
    The first entry of a non-empty list/tuple; otherwise *output_format*
    itself when it is truthy; otherwise the value returned by
    ``autoarray.plot.utils._conf_output_format()``.  An empty list or tuple
    is treated like ``None`` instead of raising ``IndexError``.
    """
    if isinstance(output_format, (list, tuple)):
        if output_format:
            return output_format[0]
        # Empty sequence: nothing was requested, so use the configured default.
        output_format = None

    if output_format:
        return output_format

    # Imported lazily: only needed when falling back to the configured default.
    from autoarray.plot.utils import _conf_output_format

    return _conf_output_format()


def _numpy_grid(grid):
    """Return *grid* as a numpy array, or ``None`` when that is not possible.

    The object's ``.array`` attribute is used when present, otherwise the
    object itself is passed to ``np.array``.  ``None`` input, or any exception
    raised during conversion, yields ``None``.
    """
    if grid is not None:
        try:
            source = grid.array if hasattr(grid, "array") else grid
            return np.array(source)
        except Exception:
            # Best-effort conversion: fall through to the ``None`` result.
            pass
    return None


def plot_array(
    array,
    title="",
    output_path=None,
    output_filename="array",
    output_format=None,
    colormap="default",
    use_log10=False,
    vmin=None,
    vmax=None,
    symmetric=False,
    positions=None,
    lines=None,
    line_colors=None,
    grid=None,
    cb_unit=None,
    ax=None,
):
    """Plot an array through ``autoarray.plot.plot_array``.

    This wrapper resolves the colormap and output format, optionally derives
    symmetric colour limits, normalises the overlay arguments and then
    delegates the rendering and output to ``autoarray.plot.plot_array``.

    Parameters
    ----------
    array
        The ``Array2D`` (or array-like) to plot.  It is forwarded unchanged;
        it is only inspected here when *symmetric* is ``True``.
    title : str
        Title for the panel.  A falsy value is forwarded as ``""``.
    output_path : str or None
        Output directory forwarded to autoarray.  When *ax* is ``None`` and
        *output_path* is ``None``, ``"."`` is forwarded.  When *ax* is given,
        ``None`` is forwarded regardless of this argument.
    output_filename : str
        Stem of the output file name.
    output_format : str, list, tuple or None
        Normalised with :func:`_resolve_format` before being forwarded.
    colormap : str or None
        Colormap; ``"default"`` or ``None`` resolve to the autoarray default
        via :func:`_resolve_colormap`.
    use_log10 : bool
        Forwarded to autoarray as ``use_log10``.
    vmin, vmax : float or None
        Explicit colour limits.  Overwritten when *symmetric* is ``True``.
    symmetric : bool
        If ``True``, set ``vmin = -m`` and ``vmax = m`` where ``m`` is the
        largest absolute finite value of the array (``1.0`` when the array
        has no finite values).
    positions : list or array-like or None
        Positions to overlay.  A ``list`` is forwarded as-is; any other
        value is converted with :func:`_to_positions`.
    lines : list or array-like or None
        Lines to overlay.  A ``list`` is forwarded as-is; any other value is
        converted with :func:`_to_lines`.
    line_colors : list or None
        Forwarded to autoarray as ``line_colors``.
    grid : array-like or None
        Extra grid to overlay, converted with :func:`_numpy_grid`.
    cb_unit
        Forwarded to autoarray as ``cb_unit``.
        NOTE(review): presumably the colour-bar unit label - confirm.
    ax : matplotlib.axes.Axes or None
        Existing ``Axes`` to draw into.  When provided, no output path is
        forwarded, so the caller is responsible for saving.
    """
    from autoarray.plot import plot_array as _aa_plot_array

    colormap = _resolve_colormap(colormap)
    output_format = _resolve_format(output_format)

    if symmetric:
        # Prefer the native array of an autoarray object; fall back to a
        # plain numpy view for anything without ``.native``.
        try:
            values = array.native.array
        except AttributeError:
            values = np.asarray(array)
        finite_values = values[np.isfinite(values)]
        if len(finite_values) > 0:
            limit = float(np.max(np.abs(finite_values)))
        else:
            limit = 1.0
        vmin = -limit
        vmax = limit

    if isinstance(positions, list):
        overlay_positions = positions
    else:
        overlay_positions = _to_positions(positions)

    if isinstance(lines, list):
        overlay_lines = lines
    else:
        overlay_lines = _to_lines(lines)

    if ax is None:
        resolved_output_path = "." if output_path is None else output_path
    else:
        # Drawing into a caller-owned Axes: do not ask autoarray to save.
        resolved_output_path = None

    _aa_plot_array(
        array=array,
        ax=ax,
        grid=_numpy_grid(grid),
        positions=overlay_positions,
        lines=overlay_lines,
        line_colors=line_colors,
        title=title or "",
        colormap=colormap,
        use_log10=use_log10,
        vmin=vmin,
        vmax=vmax,
        cb_unit=cb_unit,
        output_path=resolved_output_path,
        output_filename=output_filename,
        output_format=output_format,
    )
def _fits_values_and_header(array):
    """Split an object into raw values, header dict and extension name.

    Returns
    -------
    tuple
        ``(values, header_dict, ext_name)``:

        * ``AbstractVisibilities`` - ``np.asarray(array.in_array)``, no
          header, no extension name;
        * ``Mask`` - the mask cast to float, its ``header_dict`` when it has
          one, and the extension name ``"mask"``;
        * any object with a ``native`` attribute - ``array.native.array``
          cast to float, ``array.mask.header_dict`` when available, no
          extension name;
        * anything else - ``np.asarray(array)``, no header, no extension
          name.
    """
    from autoarray.structures.visibilities import AbstractVisibilities
    from autoarray.mask.abstract_mask import Mask

    if isinstance(array, AbstractVisibilities):
        return np.asarray(array.in_array), None, None

    if isinstance(array, Mask):
        if hasattr(array, "header_dict"):
            mask_header = array.header_dict
        else:
            mask_header = None
        return np.asarray(array.astype("float")), mask_header, "mask"

    if not hasattr(array, "native"):
        # Plain array-like: no metadata to extract.
        return np.asarray(array), None, None

    try:
        native_header = array.mask.header_dict
    except (AttributeError, TypeError):
        native_header = None
    native_values = np.asarray(array.native.array).astype("float")
    return native_values, native_header, None


def fits_array(array, file_path, overwrite=False, ext_name=None):
    """Write an autoarray ``Array2D``, ``Mask2D``, or array-like to a ``.fits`` file.

    Values, header dict and a default extension name are obtained from
    :func:`_fits_values_and_header` and passed to
    ``autoconf.fitsable.output_to_fits``.

    Parameters
    ----------
    array
        The data to write.
    file_path : str or Path
        Passed to ``output_to_fits`` as ``file_path``.
    overwrite : bool
        Passed to ``output_to_fits`` as ``overwrite``.
    ext_name : str or None
        FITS extension name.  When ``None``, the name detected from the
        object is used (``"mask"`` for masks, otherwise ``None``).
    """
    from autoconf.fitsable import output_to_fits

    values, header_dict, detected_ext_name = _fits_values_and_header(array)

    output_to_fits(
        values=values,
        file_path=file_path,
        overwrite=overwrite,
        header_dict=header_dict,
        ext_name=detected_ext_name if ext_name is None else ext_name,
    )
def plot_grid(
    grid,
    title="",
    output_path=None,
    output_filename="grid",
    output_format=None,
    lines=None,
    ax=None,
):
    """Plot a grid through ``autoarray.plot.plot_grid``.

    The grid is converted to a plain numpy array (via its ``.array``
    attribute when it has one) and handed to ``autoarray.plot.plot_grid``
    together with the resolved output settings.

    Parameters
    ----------
    grid
        The ``Grid2D`` (or grid-like) to plot.
    title : str
        Title for the panel.  A falsy value is forwarded as ``""``.
    output_path : str or None
        Output directory forwarded to autoarray.  When *ax* is ``None`` and
        *output_path* is ``None``, ``"."`` is forwarded.  When *ax* is given,
        ``None`` is forwarded regardless of this argument.
    output_filename : str
        Stem of the output file name.
    output_format : str, list, tuple or None
        Normalised with :func:`_resolve_format` before being forwarded.
    lines : list or None
        Line coordinates to overlay.
        NOTE(review): this argument is accepted but currently not forwarded
        to ``autoarray.plot.plot_grid`` - confirm whether that is intended.
    ax : matplotlib.axes.Axes or None
        Existing ``Axes`` to draw into.
    """
    from autoarray.plot import plot_grid as _aa_plot_grid

    output_format = _resolve_format(output_format)

    if ax is None:
        resolved_output_path = "." if output_path is None else output_path
    else:
        # Drawing into a caller-owned Axes: do not ask autoarray to save.
        resolved_output_path = None

    grid_source = grid.array if hasattr(grid, "array") else grid
    grid_values = np.array(grid_source)

    _aa_plot_grid(
        grid=grid_values,
        ax=ax,
        title=title or "",
        output_path=resolved_output_path,
        output_filename=output_filename,
        output_format=output_format,
    )
def _critical_curves_method():
    """Read ``general.critical_curves_method`` from the visualize config.

    Returns
    -------
    str
        ``"marching_squares"`` or ``"zero_contour"``.
        ``"marching_squares"`` is used when the config key is missing, when
        the configured value is not recognised (a warning is logged), or
        when ``"zero_contour"`` is requested but ``jax_zero_contour`` cannot
        be imported (a warning is logged).
    """
    from autoconf import conf

    try:
        method = conf.instance["visualize"]["general"]["general"]["critical_curves_method"]
    except (KeyError, TypeError):
        method = "marching_squares"

    if method not in ("zero_contour", "marching_squares"):
        logger.warning(
            f"visualize/general.yaml: unrecognised critical_curves_method "
            f"'{method}'. Falling back to 'marching_squares'."
        )
        return "marching_squares"

    if method == "zero_contour":
        # The zero-contour path needs an optional dependency; probe for it
        # here so callers can rely on the returned method being usable.
        try:
            import jax_zero_contour  # noqa: F401
        except ImportError:
            logger.warning(
                "critical_curves_method='zero_contour' requested, but "
                "jax_zero_contour is not installed (Python <3.11 ships without "
                "the [jax] extra). Falling back to 'marching_squares'."
            )
            return "marching_squares"

    return method


def _caustics_from(mass_obj, grid):
    """Compute tangential and radial caustics for a mass object via LensCalc.

    The algorithm is selected by :func:`_critical_curves_method`:

    - ``"zero_contour"`` - calls the ``*_via_zero_contour_from`` methods of
      ``LensCalc``; *grid* is not passed on this path.
    - ``"marching_squares"`` - calls the grid-based ``*_caustic_list_from``
      methods with *grid*.

    When the environment variable ``PYAUTO_FAST_PLOTS`` equals ``"1"`` no
    computation is performed and two empty lists are returned.

    Parameters
    ----------
    mass_obj
        Any object understood by ``LensCalc.from_mass_obj``.
    grid : aa.type.Grid2DLike
        Evaluation grid, used only on the ``"marching_squares"`` path.

    Returns
    -------
    tuple[list, list]
        ``(tangential_caustics, radial_caustics)``.
    """
    if os.environ.get("PYAUTO_FAST_PLOTS") == "1":
        return [], []

    from autogalaxy.operate.lens_calc import LensCalc

    od = LensCalc.from_mass_obj(mass_obj)

    if _critical_curves_method() == "zero_contour":
        tan_ca = od.tangential_caustic_list_via_zero_contour_from()
        rad_ca = od.radial_caustic_list_via_zero_contour_from()
    else:
        tan_ca = od.tangential_caustic_list_from(grid=grid)
        rad_ca = od.radial_caustic_list_from(grid=grid)

    return tan_ca, rad_ca


def _critical_curves_from(mass_obj, grid, tc=None, rc=None):
    """Compute tangential and radial critical curves for a mass object.

    If *tc* is already provided it is returned unchanged (along with *rc*),
    allowing callers to cache the curves across multiple plot calls.  The
    exception is when the environment variable ``PYAUTO_FAST_PLOTS`` equals
    ``"1"``: then two empty lists are returned regardless of *tc* and *rc*.

    When *tc* is ``None`` the algorithm is selected by
    :func:`_critical_curves_method`:

    - ``"zero_contour"`` - calls the ``*_via_zero_contour_from`` methods of
      ``LensCalc``; *grid* is not passed on this path.
    - ``"marching_squares"`` - calls the grid-based methods with *grid*.
      Radial critical curves are only computed when at least one radial
      critical-curve area exceeds ``grid.pixel_scale``; otherwise *rc* is
      returned as passed in.

    Parameters
    ----------
    mass_obj
        Any object understood by ``LensCalc.from_mass_obj``.
    grid : aa.type.Grid2DLike
        Evaluation grid, used only on the ``"marching_squares"`` path.
    tc : list or None
        Pre-computed tangential critical curves; ``None`` to trigger
        computation.
    rc : list or None
        Pre-computed radial critical curves.

    Returns
    -------
    tuple[list, list or None]
        ``(tangential_critical_curves, radial_critical_curves)``.
    """
    # Check the fast-plots switch before importing LensCalc, matching
    # ``_caustics_from``: the short-circuit needs nothing from that module.
    if os.environ.get("PYAUTO_FAST_PLOTS") == "1":
        return [], []

    if tc is not None:
        return tc, rc

    from autogalaxy.operate.lens_calc import LensCalc

    od = LensCalc.from_mass_obj(mass_obj)

    if _critical_curves_method() == "zero_contour":
        tc = od.tangential_critical_curve_list_via_zero_contour_from()
        rc = od.radial_critical_curve_list_via_zero_contour_from()
    else:
        tc = od.tangential_critical_curve_list_from(grid=grid)
        rc_area = od.radial_critical_curve_area_list_from(grid=grid)
        if any(area > grid.pixel_scale for area in rc_area):
            rc = od.radial_critical_curve_list_from(grid=grid)

    return tc, rc