Source code for autolens.point.solver.point_solver
"""
Image-plane point-source solver for strong gravitational lensing.
Finding the multiple images of a point source requires solving the lens equation
θ = β + α(θ) for θ given a fixed source-plane position β. This is an inverse
problem with no analytic solution for general mass distributions.
``PointSolver`` solves this numerically using a triangle-tiling approach:
1. The image plane is tiled with triangles.
2. Each triangle is ray-traced to the source plane.
3. Triangles that contain the source-plane coordinate β are refined recursively.
4. The centroids of the final refined triangles give the image-plane positions.
The output positions array is padded to a fixed size (``MAX_CONTAINING_SIZE``) using the
sentinel value ``inf`` for JAX compatibility — these ``inf`` entries are stripped by
default but can be retained for use inside a ``jax.jit``-traced function.
"""
import logging
import os
from typing import Tuple, Optional
import numpy as np
import autoarray as aa
from autoarray.structures.triangles.shape import Point
from autolens.lens.tracer import Tracer
from .shape_solver import AbstractSolver
logger = logging.getLogger(__name__)
[docs]
class PointSolver(AbstractSolver):
[docs]
def solve(
self,
tracer: Tracer,
source_plane_coordinate: Tuple[float, float],
xp=None,
plane_redshift: Optional[float] = None,
remove_infinities: Optional[bool] = None,
) -> aa.Grid2DIrregular:
"""
Solve for the image plane coordinates that are traced to the source plane coordinate.
This is done by tiling the image plane with triangles and checking if the source plane coordinate is contained
within the triangle. The triangles are sub-sampled to increase the resolution with only the triangles that
contain the source plane coordinate and their neighbours being kept.
The means of the triangles are then filtered to keep only those with an absolute magnification above the
threshold.
The positions are stored on an array of fixed shape defined by `MAX_CONTAINING_SIZE`. This ensures the
array is static, which is important for JAX compatibility. This array typically has many entries
which use the sentinel value of `inf`, subsequent JAX calculations incorporated. By default, these
sentinel values are removed from the output, for example general use outside of JAX when simulating
strong lenses.
Parameters
----------
tracer
The tracer that traces the image plane coordinates to the source plane.
source_plane_coordinate
The plane coordinate to trace to the image plane, which by default in the source-plane coordinate
but could be a coordinate in another plane is `plane_redshift` is input.
xp
The array module (``numpy`` or ``jax.numpy``) the solve runs in. ``AnalysisPoint``
passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis. When ``None`` (the
default), falls back to ``self._xp`` — which is ``jnp`` if the solver was constructed
with ``use_jax=True`` and ``np`` otherwise. Pass explicitly to override.
plane_redshift
The redshift of the plane coordinate, which for multi-plane systems may not be the source-plane.
remove_infinities
Whether to strip the ``inf`` sentinel rows from the output. When ``None`` (the default),
defaults to ``True`` on the NumPy path and ``False`` on the JAX path. The JAX path
keeps the padded static shape so the output crosses a ``jax.jit`` boundary cleanly;
strip the infinities outside the jit if needed.
Returns
-------
A ``Grid2DIrregular`` of image-plane coordinates. NumPy-backed on the default path,
``jax.Array``-backed when ``use_jax=True`` (or ``xp=jnp``).
Notes
-----
Smoke-test short-circuit (``PYAUTO_SMALL_DATASETS``): the triangle-tiling solve
is the dominant cost in many simulator scripts and is meaningless on the
downsized grids used for fast smoke tests. When ``PYAUTO_SMALL_DATASETS=1`` is
set the solver returns the fixed pair ``[(1.0, 0.0), (0.0, 1.0)]`` immediately,
skipping ``solve_triangles`` entirely. The two coordinates are well separated
so any downstream ``positions_likelihood_from`` / threshold calculation behaves
normally. ``PYAUTO_SMALL_DATASETS`` is a smoke-test-only flag and is never set
inside a ``jax.jit`` trace, so a plain numpy-backed ``Grid2DIrregular`` is safe
here even when the surrounding analysis uses ``xp=jnp``.
"""
if xp is None:
xp = self._xp
if remove_infinities is None:
remove_infinities = not self.use_jax
# NOTE: pytree registration is the user's responsibility (call
# `autolens.jax.register_tracer_classes(tracer)` once before wrapping
# in @jax.jit). Auto-registering inside solve() doesn't help because
# JAX flattens function arguments at trace time — before entering
# this method — so registration must run before the first jitted
# call. See the `lens_calc.py` workspace guide for the canonical
# JIT-it-yourself pattern.
if os.environ.get("PYAUTO_SMALL_DATASETS") == "1":
return aa.Grid2DIrregular(values=[(1.0, 0.0), (0.0, 1.0)])
kept_triangles = super().solve_triangles(
tracer=tracer,
shape=Point(*source_plane_coordinate),
xp=xp,
plane_redshift=plane_redshift,
)
filtered_means = self._filter_low_magnification(
tracer=tracer, points=kept_triangles.means, xp=xp
)
solution = aa.Grid2DIrregular(
[pair for pair in filtered_means], xp=xp
).array
is_nan = xp.isnan(solution).any(axis=1)
sentinel = xp.full_like(solution[0], fill_value=xp.inf)
solution = xp.where(is_nan[:, None], sentinel, solution)
if remove_infinities:
solution = solution[~xp.isinf(solution).any(axis=1)]
return aa.Grid2DIrregular(solution)