Source code for autogalaxy.operate.lens_calc

"""
`LensCalc` — a calculator that derives all secondary lensing quantities from a deflection callable.

Given any object that exposes `deflections_yx_2d_from` (a `MassProfile`, `Galaxy`, `Galaxies`, or
`Tracer`), `LensCalc` computes:

- **Hessian** (four components: H_yy, H_xy, H_yx, H_xx) — the matrix of second derivatives of the
  lensing potential, computed by finite differences (NumPy) or automatic differentiation (JAX).
- **Convergence via Hessian** — κ = 0.5 (H_yy + H_xx), independent of the analytic profile formula.
- **Shear** — γ₁ = 0.5 (H_xx − H_yy), γ₂ = H_xy.
- **Magnification** — μ = 1 / det(I − H).
- **Critical curves** — image-plane loci where det(I − H) = 0, found via marching squares.
- **Caustics** — source-plane images of the critical curves.
- **Einstein radius** — the effective radius from the area enclosed by the tangential critical curve.
- **Fermat potential** — φ(θ) = ½|θ − β|² − ψ(θ), using the optional potential callable.

The class is constructed with `LensCalc.from_mass_obj(mass)` or `LensCalc.from_tracer(tracer)`.
"""
from functools import wraps
import importlib
import logging
import numpy as np
from typing import List, Tuple, Union

from autoconf import conf

import autoarray as aa

from autogalaxy.util.shear_field import ShearYX2DIrregular

logger = logging.getLogger(__name__)

_OPTIONAL_DEP_WARNED: set = set()


def _maybe_optional_dep_warn(import_name: str, feature_name: str) -> bool:
    """
    Return True (and warn once per process for ``feature_name``) if the
    optional dependency ``import_name`` is not installed; False otherwise.

    Callers that get True must early-return a soft-fail value (NaN, empty
    list, etc.) — a search-killing raise here would discard the post-fit
    metric write of an otherwise-converged search. Mirrors PyAutoLens's
    ``_maybe_magzero_warn`` for the same reason.
    """
    try:
        importlib.import_module(import_name)
        return False
    except ModuleNotFoundError:
        if feature_name not in _OPTIONAL_DEP_WARNED:
            logger.warning(
                "Optional dependency '%s' not installed; '%s' returning "
                "NaN/empty. pip install %s to enable it.",
                import_name, feature_name, import_name,
            )
            _OPTIONAL_DEP_WARNED.add(feature_name)
        return True


def grid_scaled_2d_for_marching_squares_from(
    grid_pixels_2d: aa.Grid2D,
    shape_native: Tuple[int, int],
    mask: aa.Mask2D,
) -> aa.Grid2DIrregular:
    pixel_scales = mask.pixel_scales
    origin = mask.origin

    grid_scaled_1d = aa.util.geometry.grid_scaled_2d_slim_from(
        grid_pixels_2d_slim=grid_pixels_2d,
        shape_native=shape_native,
        pixel_scales=pixel_scales,
        origin=origin,
    )

    grid_scaled_1d[:, 0] -= pixel_scales[0] / 2.0
    grid_scaled_1d[:, 1] += pixel_scales[1] / 2.0

    return aa.Grid2DIrregular(values=grid_scaled_1d)


def evaluation_grid(func):
    @wraps(func)
    def wrapper(
        lensing_obj, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05
    ):
        if hasattr(grid, "is_evaluation_grid"):
            if grid.is_evaluation_grid:
                return func(lensing_obj, grid, pixel_scale)

        pixel_scale_ratio = grid.pixel_scale / pixel_scale

        zoom = aa.Zoom2D(mask=grid.mask)

        zoom_shape_native = zoom.shape_native
        shape_native = (
            int(pixel_scale_ratio * zoom_shape_native[0]),
            int(pixel_scale_ratio * zoom_shape_native[1]),
        )

        max_evaluation_grid_size = conf.instance["general"]["grid"][
            "max_evaluation_grid_size"
        ]

        # This is a hack to prevent the evaluation gird going beyond 1000 x 1000 pixels, which slows the code
        # down a lot. Need a better moe robust way to set this up for any general lens.

        if shape_native[0] > max_evaluation_grid_size:
            pixel_scale = pixel_scale_ratio / (
                shape_native[0] / float(max_evaluation_grid_size)
            )
            shape_native = (max_evaluation_grid_size, max_evaluation_grid_size)

        grid = aa.Grid2D.uniform(
            shape_native=shape_native,
            pixel_scales=(pixel_scale, pixel_scale),
            origin=zoom.offset_scaled,
            respect_small_datasets=False,
        )

        grid.is_evaluation_grid = True

        return func(lensing_obj, grid, pixel_scale)

    return wrapper


[docs] class LensCalc: """ Computes lensing quantities from a deflection-angle callable and an optional potential callable. The deflection callable is used to compute the Hessian, Jacobian, convergence, shear, magnification, critical curves, caustics, and Einstein radius/mass. If a potential callable is also supplied, ``fermat_potential_from`` is available as well. Parameters ---------- deflections_yx_2d_from A callable with signature ``(grid, xp=np, **kwargs)`` that returns the 2D deflection angles on the given grid. Typically a bound method of a ``MassProfile``, ``Galaxy``, or ``Galaxies`` instance. potential_2d_from Optional callable with signature ``(grid, xp=np, **kwargs)`` that returns the 2D lensing potential on the given grid. Required only for ``fermat_potential_from``. """ def __init__(self, deflections_yx_2d_from, potential_2d_from=None): self.deflections_yx_2d_from = deflections_yx_2d_from self.potential_2d_from = potential_2d_from # Maps (kind, pixel_scales, tol, max_newton) -> (f, ZeroSolver). Lets # repeat zero_contour calls reuse JAX's compile cache, since that # cache is keyed on callable identity. self._zero_contour_cache: dict = {}
[docs] @classmethod def from_mass_obj(cls, mass_obj): """Construct from any object that has a ``deflections_yx_2d_from`` method. If the object also exposes ``potential_2d_from``, it is captured so that ``fermat_potential_from`` is available on the returned instance. """ return cls( deflections_yx_2d_from=mass_obj.deflections_yx_2d_from, potential_2d_from=getattr(mass_obj, "potential_2d_from", None), )
[docs] @classmethod def from_tracer( cls, tracer, use_multi_plane: bool = True, plane_i: int = 0, plane_j: int = -1 ): """ Construct from a PyAutoLens ``Tracer`` object. The ``Tracer`` type is deliberately left unannotated: ``autogalaxy`` does not depend on ``autolens``, so no import of ``Tracer`` is performed here. Callers (which live inside ``autolens``) are responsible for passing the correct object. Parameters ---------- tracer A PyAutoLens ``Tracer`` instance. Must expose ``deflections_yx_2d_from`` and, when ``use_multi_plane=True``, ``deflections_between_planes_from``. use_multi_plane If ``True`` the stored callable is ``tracer.deflections_between_planes_from`` with ``plane_i`` and ``plane_j`` bound via ``functools.partial``, matching the multi-plane ray-tracing path. If ``False`` the stored callable is ``tracer.deflections_yx_2d_from``, i.e. the standard two-plane path. plane_i Index of the first plane used by ``deflections_between_planes_from``. Ignored when ``use_multi_plane=False``. Defaults to ``0`` (image plane). plane_j Index of the second plane used by ``deflections_between_planes_from``. Ignored when ``use_multi_plane=False``. Defaults to ``-1`` (source plane). """ potential_2d_from = getattr(tracer, "potential_2d_from", None) if use_multi_plane: from functools import partial return cls( deflections_yx_2d_from=partial( tracer.deflections_between_planes_from, plane_i=plane_i, plane_j=plane_j, ), potential_2d_from=potential_2d_from, ) return cls( deflections_yx_2d_from=tracer.deflections_yx_2d_from, potential_2d_from=potential_2d_from, )
[docs] def time_delay_geometry_term_from(self, grid, xp=np) -> aa.Array2D: """ Returns the geometric time delay term of the Fermat potential for a given grid of image-plane positions. This term is given by: .. math:: \[\tau_{\text{geom}}(\boldsymbol{\theta}) = \frac{1}{2} |\boldsymbol{\theta} - \boldsymbol{\beta}|^2\] where: - \( \boldsymbol{\theta} \) is the image-plane coordinate, - \( \boldsymbol{\beta} = \boldsymbol{\theta} - \boldsymbol{\alpha}(\boldsymbol{\theta}) \) is the source-plane coordinate, - \( \boldsymbol{\alpha} \) is the deflection angle at each image-plane coordinate. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and time delay geometric term are computed on. Returns ------- The geometric time delay term at each grid position. """ deflections = self.deflections_yx_2d_from(grid=grid, xp=xp) src_y = grid[:, 0] - deflections[:, 0] src_x = grid[:, 1] - deflections[:, 1] delay = 0.5 * ((grid[:, 0] - src_y) ** 2 + (grid[:, 1] - src_x) ** 2) if isinstance(grid, aa.Grid2DIrregular): return aa.ArrayIrregular(values=delay) return aa.Array2D(values=delay, mask=grid.mask)
[docs] def fermat_potential_from(self, grid, xp=np) -> aa.Array2D: """ Returns the Fermat potential for a given grid of image-plane positions. This is the sum of the geometric time delay term and the gravitational (Shapiro) delay term (i.e. the lensing potential), and is given by: .. math:: \\phi(\\boldsymbol{\\theta}) = \\frac{1}{2} |\\boldsymbol{\\theta} - \\boldsymbol{\\beta}|^2 - \\psi(\\boldsymbol{\\theta}) Requires that ``potential_2d_from`` was supplied at construction (e.g. via ``LensCalc.from_mass_obj`` or ``LensCalc.from_tracer``). Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the Fermat potential is computed on. xp The array module (``numpy`` or ``jax.numpy``). """ if self.potential_2d_from is None: raise ValueError( "fermat_potential_from requires a potential_2d_from callable. " "Construct LensCalc with potential_2d_from, or use from_mass_obj / from_tracer." ) time_delay = self.time_delay_geometry_term_from(grid=grid, xp=xp) potential = self.potential_2d_from(grid=grid, xp=xp) fermat_potential = time_delay - potential if isinstance(grid, aa.Grid2DIrregular): return aa.ArrayIrregular(values=fermat_potential) return aa.Array2D(values=fermat_potential, mask=grid.mask)
[docs] def tangential_eigen_value_from(self, grid, xp=np) -> aa.Array2D: """ Returns the tangential eigen values of lensing jacobian, which are given by the expression: `tangential_eigen_value = 1 - convergence - shear` Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and tangential eigen values are computed on. xp The array module (``numpy`` or ``jax.numpy``). Passed through to ``convergence_2d_via_hessian_from`` and ``shear_yx_2d_via_hessian_from``. When ``xp`` is not ``numpy`` the result is a raw array rather than an ``aa.Array2D`` wrapper. """ convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp) shear_yx = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp) if xp is np: return aa.Array2D( values=1 - convergence - shear_yx.magnitudes, mask=grid.mask ) magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2) return 1 - convergence - magnitudes
[docs] def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D: """ Returns the radial eigen values of lensing jacobian, which are given by the expression: `radial_eigen_value = 1 - convergence + shear` Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and radial eigen values are computed on. xp The array module (``numpy`` or ``jax.numpy``). Passed through to ``convergence_2d_via_hessian_from`` and ``shear_yx_2d_via_hessian_from``. When ``xp`` is not ``numpy`` the result is a raw array rather than an ``aa.Array2D`` wrapper. """ convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp) shear = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp) if xp is np: return aa.Array2D( values=1 - convergence + shear.magnitudes, mask=grid.mask ) magnitudes = xp.sqrt(shear[:, 0] ** 2 + shear[:, 1] ** 2) return 1 - convergence + magnitudes
[docs] def magnification_2d_from(self, grid, xp=np) -> aa.Array2D: """ Returns the 2D magnification map of lensing object, which is computed as the inverse of the determinant of the lensing Jacobian, expressed via the Hessian components. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and magnification map are computed on. xp The array module (``numpy`` or ``jax.numpy``). Passed through to ``hessian_from``. When ``xp`` is not ``numpy`` the result is a raw array rather than an ``aa.Array2D`` wrapper. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( grid=grid, xp=xp ) det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx if xp is np: return aa.Array2D(values=1 / det_A, mask=grid.mask) return 1 / det_A
[docs] def deflections_yx_scalar(self, y, x, pixel_scales): """ Returns the deflection angles at a single (y, x) arc-second coordinate as a JAX array of shape (2,), where index 0 is the y-deflection and index 1 is the x-deflection. This is an internal method used by `hessian_from` to enable JAX auto-differentiation via `jax.jacfwd`. The function must accept y and x as two separate scalar inputs (rather than a single combined array) so that JAX treats the function as R² -> R² and computes a proper 2x2 Jacobian matrix. Parameters ---------- y The y arc-second coordinate (scalar). x The x arc-second coordinate (scalar). pixel_scales The pixel scales used to construct the internal (1, 1) Mask2D. """ import jax.numpy as jnp mask = aa.Mask2D.all_false(shape_native=(1, 1), pixel_scales=pixel_scales) grid = aa.Grid2D( values=jnp.stack((y.reshape(1), x.reshape(1)), axis=-1), mask=mask ) return self.deflections_yx_2d_from(grid, xp=jnp).squeeze()
[docs] def hessian_from(self, grid, xp=np) -> Tuple: """ Returns the Hessian of the lensing object, where the Hessian is the second partial derivatives of the potential (see equation 55 https://inspirehep.net/literature/419263): `hessian_{i,j} = d^2 / dtheta_i dtheta_j` The Hessian is returned as a 4-entry tuple reflecting its structure as a 2x2 matrix: (hessian_yy, hessian_xy, hessian_yx, hessian_xx). Two computational paths are available, selected via the `xp` parameter: - **NumPy** (``xp=np``, default): 2-point central finite-difference approximation, Richardson-extrapolated at step sizes ``h`` and ``h/2`` and combined as ``(4 * H(h/2) - H(h)) / 3``. This cancels the leading ``O(h^2)`` truncation term, giving ``O(h^4)`` accuracy and matching the JAX path to float64 precision. JAX is not imported. - **JAX** (``xp=jnp``): exact derivatives via ``jax.jacfwd`` applied to ``deflections_yx_scalar``, vectorised over the grid with ``jnp.vectorize``. Both paths support uniform ``Grid2D`` and irregular ``Grid2DIrregular`` grids. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the Hessian is computed on. xp The array module (``numpy`` or ``jax.numpy``). Controls which computational path is used and the type of the returned arrays. """ if xp is np: return self._hessian_via_richardson(grid=grid) return self._hessian_via_jax(grid=grid, xp=xp)
def _hessian_via_richardson(self, grid, buffer: float = 0.01) -> Tuple: yy_h, xy_h, yx_h, xx_h = self._hessian_via_finite_difference( grid=grid, buffer=buffer ) yy_h2, xy_h2, yx_h2, xx_h2 = self._hessian_via_finite_difference( grid=grid, buffer=buffer / 2.0 ) hessian_yy = (4.0 * yy_h2 - yy_h) / 3.0 hessian_xy = (4.0 * xy_h2 - xy_h) / 3.0 hessian_yx = (4.0 * yx_h2 - yx_h) / 3.0 hessian_xx = (4.0 * xx_h2 - xx_h) / 3.0 return hessian_yy, hessian_xy, hessian_yx, hessian_xx def _hessian_via_jax(self, grid, xp) -> Tuple: import jax import jax.numpy as jnp pixel_scales = getattr(grid, "pixel_scales", (0.05, 0.05)) y = jnp.array(grid[:, 0]) x = jnp.array(grid[:, 1]) def _hessian_single(y_scalar, x_scalar): return jnp.stack( jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))( y_scalar, x_scalar, pixel_scales ) ) h = jnp.vectorize(_hessian_single, signature="(),()->(i,i)")(y, x) # h has shape (N, 2, 2): # h[..., 0, 0] = d(defl_y)/dy = hessian_yy # h[..., 0, 1] = d(defl_x)/dy = hessian_xy # h[..., 1, 0] = d(defl_y)/dx = hessian_yx # h[..., 1, 1] = d(defl_x)/dx = hessian_xx return ( xp.array(h[..., 0, 0]), xp.array(h[..., 0, 1]), xp.array(h[..., 1, 0]), xp.array(h[..., 1, 1]), ) def _hessian_via_finite_difference(self, grid, buffer: float = 0.01) -> Tuple: grid_values = grid.array if hasattr(grid, "array") else grid g0 = np.array(grid_values[:, 0], dtype=np.float64, copy=True) g1 = np.array(grid_values[:, 1], dtype=np.float64, copy=True) grid_shift_y_up = aa.Grid2DIrregular( values=np.stack([g0 + buffer, g1], axis=1) ) grid_shift_y_down = aa.Grid2DIrregular( values=np.stack([g0 - buffer, g1], axis=1) ) grid_shift_x_left = aa.Grid2DIrregular( values=np.stack([g0, g1 - buffer], axis=1) ) grid_shift_x_right = aa.Grid2DIrregular( values=np.stack([g0, g1 + buffer], axis=1) ) deflections_up = self.deflections_yx_2d_from(grid=grid_shift_y_up) deflections_down = self.deflections_yx_2d_from(grid=grid_shift_y_down) deflections_left = self.deflections_yx_2d_from(grid=grid_shift_x_left) deflections_right = self.deflections_yx_2d_from(grid=grid_shift_x_right) hessian_yy = 0.5 * (deflections_up[:, 0] - deflections_down[:, 0]) / buffer hessian_xy = 0.5 * (deflections_up[:, 1] - deflections_down[:, 1]) / buffer hessian_yx = 0.5 * (deflections_right[:, 0] - deflections_left[:, 0]) / buffer hessian_xx = 0.5 * (deflections_right[:, 1] - deflections_left[:, 1]) / buffer return hessian_yy, hessian_xy, hessian_yx, hessian_xx
[docs] def jacobian_from(self, grid, xp=np) -> List: """ Returns the lensing Jacobian of the lensing object as a 2x2 list of lists. The Jacobian is the matrix `A = I - H`, where `H` is the Hessian matrix of the deflection angles: ``A = [[1 - hessian_xx, -hessian_xy], [-hessian_yx, 1 - hessian_yy]]`` It is computed from `hessian_from`, so it supports both uniform and irregular grids and accepts the same `xp` parameter for JAX acceleration. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the Jacobian is computed on. xp The array module (``numpy`` or ``jax.numpy``). Passed through to ``hessian_from``. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( grid=grid, xp=xp ) a11 = 1 - hessian_xx a12 = -hessian_xy a21 = -hessian_yx a22 = 1 - hessian_yy return [[a11, a12], [a21, a22]]
[docs] def convergence_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular: """ Returns the convergence of the lensing object, which is computed from the 2D deflection angle map via the Hessian using the expression (see equation 56 https://inspirehep.net/literature/419263): `convergence = 0.5 * (hessian_{0,0} + hessian_{1,1}) = 0.5 * (hessian_xx + hessian_yy)` By going via the Hessian, the convergence can be calculated at any (y,x) coordinate therefore using either a 2D uniform or irregular grid. This calculation of the convergence is independent of analytic calculations defined within `MassProfile` objects and can therefore be used as a cross-check. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and Hessian are computed on. xp The array module to use for the computation (e.g. `numpy` or `jax.numpy`). Passed through to `hessian_from`. When `xp` is not `numpy` (e.g. inside a `jax.jit` trace) the result is returned as a raw array rather than an `aa.ArrayIrregular` wrapper. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( grid=grid, xp=xp ) convergence = 0.5 * (hessian_yy + hessian_xx) if xp is np: return aa.ArrayIrregular(values=convergence) return convergence
[docs] def shear_yx_2d_via_hessian_from(self, grid, xp=np) -> ShearYX2DIrregular: r""" Returns the 2D weak-lensing shear vector field of the lensing object, computed from the deflection-angle Hessian (see equation 57 of https://inspirehep.net/literature/419263): .. math:: \gamma_1 = \tfrac{1}{2} (\partial^2 \psi / \partial x^2 - \partial^2 \psi / \partial y^2) = 0.5 \, (H_{xx} - H_{yy}) \gamma_2 = \partial^2 \psi / \partial x \partial y = H_{xy} = H_{yx} where :math:`\psi(\theta)` is the lensing potential and :math:`H_{ij}` are its second partial derivatives, i.e. the components of the Hessian returned by ``hessian_from``. The two components together encode the full distortion tensor of a background source: :math:`\gamma_1` is the stretch along the x/y axes and :math:`\gamma_2` is the stretch along the diagonals. Because the Hessian is computed numerically (Richardson-extrapolated finite differences in NumPy, or ``jax.jacfwd`` in JAX), this routine works for **any** ``MassProfile``, ``Galaxy``, ``Galaxies`` or ``Tracer`` that exposes ``deflections_yx_2d_from`` — no per-profile analytic formula is required. It can therefore be used as a numerical cross-check of analytic shear methods such as ``Isothermal.shear_yx_2d_from`` (see ``test_isothermal.py``). It also accepts both uniform ``Grid2D`` and irregular ``Grid2DIrregular`` grids, which makes it the natural choice for simulating a weak-lensing shear field on either a regular pixel grid or at the discrete positions of background source galaxies. Convention ---------- The result is returned as a ``ShearYX2DIrregular`` data structure of shape ``[total_shear_vectors, 2]``, where: - ``[:, 0]`` are the :math:`\gamma_2` values - ``[:, 1]`` are the :math:`\gamma_1` values Note therefore that the FIRST column is :math:`\gamma_2` and the SECOND column is :math:`\gamma_1`. This ordering matches the ``[y, x]`` ordering used elsewhere in the project for vector fields, since :math:`\gamma_2` plays the role of the "y" component and :math:`\gamma_1` the "x" component when the shear field is treated as a 2D pseudo-vector. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and Hessian are computed on. xp The array module to use for the computation (e.g. ``numpy`` or ``jax.numpy``). Passed through to ``hessian_from``. When ``xp`` is not ``numpy`` (e.g. inside a ``jax.jit`` trace) the result is returned as a raw array of shape ``(N, 2)`` rather than a ``ShearYX2DIrregular`` wrapper, because ``ShearYX2DIrregular`` is not a registered JAX pytree. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( grid=grid, xp=xp ) gamma_1 = 0.5 * (hessian_xx - hessian_yy) gamma_2 = hessian_xy shear_yx_2d = xp.stack([gamma_2, gamma_1], axis=-1) if xp is np: return ShearYX2DIrregular(values=shear_yx_2d, grid=grid) return shear_yx_2d
[docs] def magnification_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular: """ Returns the 2D magnification map of lensing object, which is computed from the 2D deflection angle map via the Hessian using the expressions (see equation 60 https://inspirehep.net/literature/419263): `magnification = 1.0 / det(Jacobian) = 1.0 / abs((1.0 - convergence)**2.0 - shear**2.0)` `magnification = (1.0 - hessian_{0,0}) * (1.0 - hessian_{1, 1)) - hessian_{0,1}*hessian_{1,0}` `magnification = (1.0 - hessian_xx) * (1.0 - hessian_yy)) - hessian_xy*hessian_yx` By going via the Hessian, the magnification can be calculated at any (y,x) coordinate, therefore using either a 2D uniform or irregular grid. This calculation of the magnification is independent of calculations using the Jacobian and can therefore be used as a cross-check. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and magnification map are computed on. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( grid=grid, xp=xp ) det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx if xp is np: return aa.ArrayIrregular(values=1.0 / det_A) return 1.0 / det_A
[docs] def contour_list_from(self, grid, contour_array): grid_contour = aa.Grid2DContour( grid=grid, pixel_scales=grid.pixel_scales, shape_native=grid.shape_native, contour_array=contour_array.native, ) return grid_contour.contour_list
[docs] @evaluation_grid def tangential_critical_curve_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[aa.Grid2DIrregular]: """ Returns all tangential critical curves of the lensing system, which are computed as follows: 1) Compute the tangential eigen values for every coordinate on the input grid via the Jacobian. 2) Find contours of all values in the tangential eigen values that are zero using a marching squares algorithm. Due to the use of a marching squares algorithm that requires the zero values of the tangential eigen values to be computed, critical curves can only be calculated using the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and tangential eigen values are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the critical curve to be computed more accurately using a higher resolution grid. """ tangential_eigen_values = self.tangential_eigen_value_from(grid=grid) vals = tangential_eigen_values.native if vals.min() > -1e-6 or vals.max() < 1e-6: return [] return self.contour_list_from(grid=grid, contour_array=tangential_eigen_values)
[docs] @evaluation_grid def radial_critical_curve_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[aa.Grid2DIrregular]: """ Returns all radial critical curves of the lensing system, which are computed as follows: 1) Compute the radial eigen values for every coordinate on the input grid via the Jacobian. 2) Find contours of all values in the radial eigen values that are zero using a marching squares algorithm. Due to the use of a marching squares algorithm that requires the zero values of the radial eigen values to be computed, this critical curves can only be calculated using the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and radial eigen values are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the critical curve to be computed more accurately using a higher resolution grid. """ radial_eigen_values = self.radial_eigen_value_from(grid=grid) vals = radial_eigen_values.native if vals.min() > -1e-6 or vals.max() < 1e-6: return [] return self.contour_list_from(grid=grid, contour_array=radial_eigen_values)
[docs] @evaluation_grid def tangential_caustic_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[aa.Grid2DIrregular]: """ Returns all tangential caustics of the lensing system, which are computed as follows: 1) Compute the tangential eigen values for every coordinate on the input grid via the Jacobian. 2) Find contours of all values in the tangential eigen values that are zero using a marching squares algorithm. 3) Compute the lensing system's deflection angles at the (y,x) coordinates of the tangential critical curve contours and ray-trace it to the source-plane, therefore forming the tangential caustics. Due to the use of a marching squares algorithm that requires the zero values of the tangential eigen values to be computed, caustics can only be calculated using the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and tangential eigen values are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ tangential_critical_curve_list = self.tangential_critical_curve_list_from( grid=grid, pixel_scale=pixel_scale ) tangential_caustic_list = [] for tangential_critical_curve in tangential_critical_curve_list: deflections_critical_curve = self.deflections_yx_2d_from( grid=tangential_critical_curve ) tangential_caustic_list.append( tangential_critical_curve - deflections_critical_curve ) return tangential_caustic_list
[docs] @evaluation_grid def radial_caustic_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[aa.Grid2DIrregular]: """ Returns all radial caustics of the lensing system, which are computed as follows: 1) Compute the radial eigen values for every coordinate on the input grid via the Jacobian. 2) Find contours of all values in the radial eigen values that are zero using a marching squares algorithm. 3) Compute the lensing system's deflection angles at the (y,x) coordinates of the radial critical curve contours and ray-trace it to the source-plane, therefore forming the radial caustics. Due to the use of a marching squares algorithm that requires the zero values of the radial eigen values to be computed, this caustics can only be calculated using the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles and radial eigen values are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ radial_critical_curve_list = self.radial_critical_curve_list_from( grid=grid, pixel_scale=pixel_scale ) radial_caustic_list = [] for radial_critical_curve in radial_critical_curve_list: deflections_critical_curve = self.deflections_yx_2d_from( grid=radial_critical_curve ) radial_caustic_list.append( radial_critical_curve - deflections_critical_curve ) return radial_caustic_list
[docs] @evaluation_grid def radial_critical_curve_area_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] ) -> List[float]: """ Returns the surface area within each radial critical curve as a list, the calculation of which is described in the function `radial_critical_curve_list_from()`. The area is computed via a line integral. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the radial critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ radial_critical_curve_list = self.radial_critical_curve_list_from( grid=grid, pixel_scale=pixel_scale ) return self.area_within_curve_list_from(curve_list=radial_critical_curve_list)
[docs] @evaluation_grid def tangential_critical_curve_area_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[float]: """ Returns the surface area within each tangential critical curve as a list, the calculation of which is described in the function `tangential_critical_curve_list_from()`. The area is computed via a line integral. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the tangential critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ tangential_critical_curve_list = self.tangential_critical_curve_list_from( grid=grid, pixel_scale=pixel_scale ) return self.area_within_curve_list_from( curve_list=tangential_critical_curve_list )
[docs] def area_within_curve_list_from( self, curve_list: List[aa.Grid2DIrregular] ) -> List[float]: area_within_each_curve_list = [] for curve in curve_list: x, y = curve[:, 0], curve[:, 1] area = np.abs(0.5 * np.sum(y[:-1] * np.diff(x) - x[:-1] * np.diff(y))) area_within_each_curve_list.append(area) return area_within_each_curve_list
[docs] @evaluation_grid def einstein_radius_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ): """ Returns a list of the Einstein radii corresponding to the area within each tangential critical curve. Each Einstein radius is defined as the radius of the circle which contains the same area as the area within each tangential critical curve. This definition is sometimes referred to as the "effective Einstein radius" in the literature and is commonly adopted in studies, for example the SLACS series of papers. The calculation of the tangential critical curves and their areas is described in the functions `tangential_critical_curve_list_from()` and `tangential_critical_curve_area_list_from()`. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the tangential critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ try: area_list = self.tangential_critical_curve_area_list_from( grid=grid, pixel_scale=pixel_scale ) return [np.sqrt(area / np.pi) for area in area_list] except TypeError: raise TypeError("The grid input was unable to estimate the Einstein Radius")
[docs] @evaluation_grid def einstein_radius_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ): """ Returns the Einstein radius corresponding to the area within the tangential critical curve. The Einstein radius is defined as the radius of the circle which contains the same area as the area within the tangential critical curve. This definition is sometimes referred to as the "effective Einstein radius" in the literature and is commonly adopted in studies, for example the SLACS series of papers. If there are multiple tangential critical curves (e.g. because the mass distribution is complex) this function raises an error, and the function `einstein_radius_list_from()` should be used instead. The calculation of the tangential critical curves and their areas is described in the functions `tangential_critical_curve_list_from()` and `tangential_critical_curve_area_list_from()`. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the tangential critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ einstein_radii_list = self.einstein_radius_list_from(grid=grid) if len(einstein_radii_list) > 1: logger.info( """ There are multiple tangential critical curves, and the computed Einstein radius is the sum of all of them. Check the `einstein_radius_list_from` function for the individual Einstein. """ ) return sum(einstein_radii_list)
[docs] @evaluation_grid def einstein_mass_angular_list_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> List[float]: """ Returns a list of the angular Einstein massses corresponding to the area within each tangential critical curve. The angular Einstein mass is defined as: `einstein_mass = pi * einstein_radius ** 2.0` where the Einstein radius is the radius of the circle which contains the same area as the area within the tangential critical curve. The Einstein mass is returned in units of arcsecond**2.0 and requires division by the lensing critical surface density \sigma_cr to be converted to physical units like solar masses (see `autogalaxy.util.cosmology_util`). This definition of Eisntein radius (and therefore mass) is sometimes referred to as the "effective Einstein radius" in the literature and is commonly adopted in studies, for example the SLACS series of papers. The calculation of the einstein radius is described in the function `einstein_radius_from()`. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the tangential critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ einstein_radius_list = self.einstein_radius_list_from( grid=grid, pixel_scale=pixel_scale ) return [np.pi * einstein_radius**2 for einstein_radius in einstein_radius_list]
[docs] @evaluation_grid def einstein_mass_angular_from( self, grid, pixel_scale: Union[Tuple[float, float], float] = 0.05 ) -> float: """ Returns the Einstein radius corresponding to the area within the tangential critical curve. The angular Einstein mass is defined as: `einstein_mass = pi * einstein_radius ** 2.0` where the Einstein radius is the radius of the circle which contains the same area as the area within the tangential critical curve. The Einstein mass is returned in units of arcsecond**2.0 and requires division by the lensing critical surface density \sigma_cr to be converted to physical units like solar masses (see `autogalaxy.util.cosmology_util`). This definition of Eisntein radius (and therefore mass) is sometimes referred to as the "effective Einstein radius" in the literature and is commonly adopted in studies, for example the SLACS series of papers. The calculation of the einstein radius is described in the function `einstein_radius_from()`. Due to the use of a marching squares algorithm to estimate the critical curve, this function can only use the Jacobian and a uniform 2D grid. Parameters ---------- grid The 2D grid of (y,x) arc-second coordinates the deflection angles used to calculate the tangential critical curve are computed on. pixel_scale If input, the `evaluation_grid` decorator creates the 2D grid at this resolution, therefore enabling the caustic to be computed more accurately using a higher resolution grid. """ einstein_mass_angular_list = self.einstein_mass_angular_list_from( grid=grid, pixel_scale=pixel_scale ) if len(einstein_mass_angular_list) > 1: logger.info( """ There are multiple tangential critical curves, and the computed Einstein mass is the sum of all of them. Check the `einstein_mass_list_from` function for the individual Einstein. """ ) return einstein_mass_angular_list[0]
# ------------------------------------------------------------------------- # jax_zero_contour-based critical curves / caustics # ------------------------------------------------------------------------- def _make_eigen_fn(self, kind: str, pixel_scales=(0.05, 0.05)): """Return a JAX scalar function ``f(pos) -> eigen_value``. ``pos`` has shape ``(2,)`` (y, x) — ``ZeroSolver.newton`` is vmapped over the init_guess rows, passing each row slice individually to ``jax.lax.custom_root(f, ...)``. The function is fully JAX-differentiable: ``ZeroSolver`` calls ``jacfwd``/``jacrev`` on it internally (Newton's method on the eigen value requires the second derivative of the deflections). ``hessian_xy`` is symmetrised as ``0.5 * (H[0,1] + H[1,0])`` to guard against numerically non-curl-free deflection fields. Parameters ---------- kind ``"tangential"`` (eigen value = ``1 - κ - |γ|``) or ``"radial"`` (``1 - κ + |γ|``). pixel_scales Forwarded to ``deflections_yx_scalar`` for its internal single-pixel ``Mask2D``. """ import jax import jax.numpy as jnp try: from jax.tree_util import Partial except ImportError: from functools import partial as Partial # Capture as local names so the closure holds no `self` reference. # ZeroSolver.zero_contour_finder is jit-compiled with `f` as a # non-static argument, so it must be a JAX pytree. Wrapping in # Partial with no dynamic args gives a pytree whose treedef is the # closure itself and whose leaves list is empty. _deflections_yx_scalar = self.deflections_yx_scalar _pixel_scales = pixel_scales _sign = -1.0 if kind == "tangential" else 1.0 def _f(pos): y, x = pos[0], pos[1] H = jnp.stack( jax.jacfwd(_deflections_yx_scalar, argnums=(0, 1))( y, x, _pixel_scales ) ) convergence = 0.5 * (H[0, 0] + H[1, 1]) gamma_1 = 0.5 * (H[1, 1] - H[0, 0]) gamma_2 = 0.5 * (H[0, 1] + H[1, 0]) # symmetrised shear = jnp.sqrt(gamma_1 ** 2 + gamma_2 ** 2) return 1.0 - convergence + _sign * shear return Partial(_f) def _make_tangential_eigen_fn(self, pixel_scales=(0.05, 0.05)): """Return a JAX scalar function ``f(pos) -> tangential_eigen_value``.""" return self._make_eigen_fn(kind="tangential", pixel_scales=pixel_scales) def _make_radial_eigen_fn(self, pixel_scales=(0.05, 0.05)): """Return a JAX scalar function ``f(pos) -> radial_eigen_value``.""" return self._make_eigen_fn(kind="radial", pixel_scales=pixel_scales) def _init_guess_from_coarse_grid( self, kind: str = "tangential", grid_shape: Tuple[int, int] = (25, 25), grid_extent: float = 3.0, ): """Return a rough initial-guess array near the critical curve. Evaluates the eigen values on a very coarse uniform grid (default 25 × 25 = 625 evaluations, versus ~250 000 for the production grid) and runs the existing marching-squares ``contour_list_from`` on it to find seed points — one per distinct curve segment. The midpoint of each coarse segment is taken as the initial guess. Parameters ---------- kind ``"tangential"`` or ``"radial"``. grid_shape Number of pixels along each axis of the coarse evaluation grid. grid_extent Half-width of the coarse grid in arc-seconds. Returns ------- jax.numpy.ndarray of shape ``(n_curves, 2)`` """ import jax.numpy as jnp pixel_scale = 2.0 * grid_extent / grid_shape[0] grid = aa.Grid2D.uniform( shape_native=grid_shape, pixel_scales=(pixel_scale, pixel_scale), ) if kind == "tangential": eigen_values = self.tangential_eigen_value_from(grid=grid) else: eigen_values = self.radial_eigen_value_from(grid=grid) coarse_curves = self.contour_list_from( grid=grid, contour_array=eigen_values ) if not coarse_curves: raise ValueError( f"No {kind} critical curve found within the coarse grid " f"(extent ±{grid_extent} arcsec, shape {grid_shape}). " "Pass an explicit `init_guess` or increase `grid_extent`." ) seeds = [curve[len(curve) // 2] for curve in coarse_curves] return jnp.array(seeds) def _critical_curve_list_via_zero_contour( self, kind: str, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """Shared implementation for tangential/radial critical curves. Parameters ---------- kind ``"tangential"`` or ``"radial"``. init_guess JAX or NumPy array of shape ``(n, 2)``. ``None`` triggers an automatic coarse-grid seed search. delta Arc-second step size along the contour. N Maximum steps in each direction from each seed. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. """ if _maybe_optional_dep_warn( "jax_zero_contour", "critical_curve_list_via_zero_contour" ): return [] from jax_zero_contour import ZeroSolver import jax.numpy as jnp if init_guess is None: try: init_guess = self._init_guess_from_coarse_grid(kind=kind) except ValueError: return [] init_guess = jnp.atleast_2d(jnp.asarray(init_guess)) # Reuse the (f, solver) pair across calls so ZeroSolver hits its JAX # compile cache. A fresh f / solver per call costs ~10s every time # because the cache is keyed on callable identity. cache_key = (kind, pixel_scales, tol, max_newton) if cache_key not in self._zero_contour_cache: self._zero_contour_cache[cache_key] = ( self._make_eigen_fn(kind=kind, pixel_scales=pixel_scales), ZeroSolver(tol=tol, max_newton=max_newton), ) f, solver = self._zero_contour_cache[cache_key] paths, _ = solver.zero_contour_finder(f, init_guess, delta=delta, N=N) paths = ZeroSolver.path_reduce(paths) closed_paths = [] for path in paths["path"]: if len(path) > 1: pts = np.array(path) pts = np.vstack([pts, pts[0]]) # close the curve closed_paths.append(aa.Grid2DIrregular(values=pts)) return closed_paths
[docs] def tangential_critical_curve_list_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """ Returns tangential critical curves using the ``jax_zero_contour`` package. Unlike ``tangential_critical_curve_list_from``, this method does not evaluate lensing quantities on a dense uniform grid. Instead it traces the zero contour of the tangential eigen value directly, evaluating the function only along the curve itself. The algorithm (from ``jax_zero_contour.ZeroSolver``): 1. Newton's method projects each initial guess onto the exact zero contour of the tangential eigen value. 2. Euler-Lagrange (gradient-perpendicular) stepping traces the contour in both directions from each projected seed point. 3. Tracing stops when the path closes, hits an endpoint, or exhausts ``N`` steps. Parameters ---------- init_guess JAX or NumPy array of shape ``(n, 2)`` with rough ``(y, x)`` positions near the tangential critical curve — one seed per distinct curve. If ``None`` a coarse 25 × 25 grid scan is used to find seed points automatically. delta Arc-second step size along the contour. Smaller values give denser, more accurate curves but require a larger ``N`` to trace the same total length. N Maximum number of steps in each direction from each seed point. The traced path has at most ``2N + 1`` points per seed. pixel_scales Pixel scales passed to ``deflections_yx_scalar`` for its internal single-pixel mask. tol Newton's method convergence tolerance (forwarded to ``ZeroSolver``). max_newton Maximum Newton iterations per step (forwarded to ``ZeroSolver``). Returns ------- List[aa.Grid2DIrregular] One ``Grid2DIrregular`` per traced contour segment, matching the return type of ``tangential_critical_curve_list_from``. """ return self._critical_curve_list_via_zero_contour( kind="tangential", init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, )
[docs] def radial_critical_curve_list_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """ Returns radial critical curves using the ``jax_zero_contour`` package. Identical to ``tangential_critical_curve_list_via_zero_contour_from`` except the zero contour of the *radial* eigen value is traced. Parameters ---------- init_guess JAX or NumPy array of shape ``(n, 2)`` with rough ``(y, x)`` positions near the radial critical curve. If ``None`` a coarse grid scan finds seed points automatically. delta Arc-second step size along the contour. N Maximum number of steps in each direction from each seed. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- List[aa.Grid2DIrregular] """ return self._critical_curve_list_via_zero_contour( kind="radial", init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, )
def _caustic_list_via_zero_contour( self, kind: str, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """Shared implementation for tangential/radial caustics.""" cc_list = self._critical_curve_list_via_zero_contour( kind=kind, init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, ) return [ cc - self.deflections_yx_2d_from(grid=cc) for cc in cc_list ]
[docs] def tangential_caustic_list_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """ Returns tangential caustics by ray-tracing the tangential critical curves computed via ``tangential_critical_curve_list_via_zero_contour_from``. Parameters ---------- init_guess Forwarded to ``tangential_critical_curve_list_via_zero_contour_from``. delta Arc-second step size along the contour. N Maximum steps per seed direction. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- List[aa.Grid2DIrregular] """ return self._caustic_list_via_zero_contour( kind="tangential", init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, )
[docs] def radial_caustic_list_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[aa.Grid2DIrregular]: """ Returns radial caustics by ray-tracing the radial critical curves computed via ``radial_critical_curve_list_via_zero_contour_from``. Parameters ---------- init_guess Forwarded to ``radial_critical_curve_list_via_zero_contour_from``. delta Arc-second step size along the contour. N Maximum steps per seed direction. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- List[aa.Grid2DIrregular] """ return self._caustic_list_via_zero_contour( kind="radial", init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, )
[docs] def einstein_radius_list_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> List[float]: """ Returns a list of Einstein radii from the tangential critical curves traced via ``tangential_critical_curve_list_via_zero_contour_from``. Each Einstein radius is the radius of the circle with the same area as the corresponding tangential critical curve. Parameters ---------- init_guess Forwarded to ``tangential_critical_curve_list_via_zero_contour_from``. delta Arc-second step size along the contour. N Maximum steps per seed direction. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- List[float] """ tangential_critical_curve_list = ( self.tangential_critical_curve_list_via_zero_contour_from( init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, ) ) area_list = self.area_within_curve_list_from( curve_list=tangential_critical_curve_list ) return [np.sqrt(area / np.pi) for area in area_list]
[docs] def einstein_radius_via_zero_contour_from( self, init_guess=None, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ) -> float: """ Returns the Einstein radius from the tangential critical curve traced via ``jax_zero_contour``. If there are multiple tangential critical curves the radii are summed, consistent with ``einstein_radius_from``. Parameters ---------- init_guess Forwarded to ``einstein_radius_list_via_zero_contour_from``. delta Arc-second step size along the contour. N Maximum steps per seed direction. pixel_scales Pixel scales passed to ``deflections_yx_scalar``. tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- float """ return sum( self.einstein_radius_list_via_zero_contour_from( init_guess=init_guess, delta=delta, N=N, pixel_scales=pixel_scales, tol=tol, max_newton=max_newton, ) )
[docs] def einstein_radius_jit_from( self, init_guess, delta: float = 0.05, N: int = 500, pixel_scales: Tuple[float, float] = (0.05, 0.05), tol: float = 1e-6, max_newton: int = 5, ): """ JIT-friendly Einstein radius from the tangential critical curve. Designed for use inside ``jax.jit(...)`` (e.g. the per-sample latent loop in ``autofit.Analysis.compute_latent_samples`` with ``LATENT_BATCH_MODE='jit'``). Unlike ``einstein_radius_via_zero_contour_from``, this method: - Requires an explicit ``init_guess`` — no marching-squares seed search (which would require ``skimage`` and break the JAX trace). - Skips ``ZeroSolver.path_reduce`` (variable-length output, not jit-able). - Computes the enclosed area directly from the raw NaN-padded paths array via a JAX-vectorised shoelace formula. - Returns a single scalar ``jax.Array`` (not a Python ``float``). Do NOT combine this with ``jax.vmap`` — the underlying ``jax_zero_contour.ZeroSolver.zero_contour_finder`` uses ``jax.lax.cond`` / ``jax.lax.while_loop`` for early termination, which upstream explicitly documents as incompatible with ``jax.vmap``. For batched evaluation, wrap the caller in ``jax.jit`` and loop over samples in Python (see ``Analysis.LATENT_BATCH_MODE='jit'``). Parameters ---------- init_guess JAX or NumPy array of shape ``(n_seeds, 2)`` — seed positions (y, x) near the expected critical curve. For typical galaxy-scale lenses centred on the image, a single seed at ``[[1.0, 0.0]]`` works; for clusters or off-centre lenses pass a small fan of seeds (e.g. four at cardinal positions). delta Arc-second step size along the contour. Forwarded to ``ZeroSolver``. N Maximum steps per seed direction. Forwarded to ``ZeroSolver``. pixel_scales Pixel scales passed to ``deflections_yx_scalar`` (used internally by ``_make_eigen_fn``). tol Newton's method convergence tolerance. max_newton Maximum Newton iterations per step. Returns ------- jax.Array Scalar Einstein radius in arc-seconds (``sqrt(area / pi)`` where ``area`` is the largest enclosed area across all seeds — robust to multiple seeds landing on the same critical curve). """ if _maybe_optional_dep_warn( "jax_zero_contour", "einstein_radius_jit_from" ): return float("nan") from jax_zero_contour import ZeroSolver import jax.numpy as jnp init_guess = jnp.atleast_2d(jnp.asarray(init_guess)) # Reuse the (f, ZeroSolver) pair from PR #434's cache so the warm # call inside a jit-cached `compute_latent_variables` reuses JAX's # compile cache. cache_key = ("tangential", pixel_scales, tol, max_newton) if cache_key not in self._zero_contour_cache: self._zero_contour_cache[cache_key] = ( self._make_eigen_fn(kind="tangential", pixel_scales=pixel_scales), ZeroSolver(tol=tol, max_newton=max_newton), ) f, solver = self._zero_contour_cache[cache_key] # `paths["path"]` has shape (n_seeds, max_steps, 2) — NaN-padded for # invalid / post-termination steps. We compute the shoelace area # directly on the raw array, masking NaN-padded terms to zero. paths, _ = solver.zero_contour_finder(f, init_guess, delta=delta, N=N) paths_arr = paths["path"] y = paths_arr[..., 0] x = paths_arr[..., 1] y_next = jnp.roll(y, -1, axis=-1) x_next = jnp.roll(x, -1, axis=-1) terms = x * y_next - x_next * y valid_terms = jnp.where(jnp.isfinite(terms), terms, 0.0) area_per_seed = 0.5 * jnp.abs(jnp.sum(valid_terms, axis=-1)) # If multiple seeds landed on the same curve, each computes the same # area; take the max so we don't double-count nor under-count when # some seeds diverged. area = jnp.max(area_per_seed) return jnp.sqrt(area / jnp.pi)