Source code for autolens.point.model.analysis

"""
Analysis class for fitting a ``Tracer`` model to a point-source dataset.

``AnalysisPoint`` implements the ``log_likelihood_function`` called by a ``PyAutoFit``
non-linear search at each iteration.  It:

1. Constructs a ``Tracer`` from the current model instance.
2. Calls ``FitPointDataset`` to fit the point-source positions (and optionally fluxes
   and time delays) using the ``PointSolver`` to find predicted image positions.
3. Optionally adds a position-based prior via ``PositionsLH`` that penalises models
   where image positions are not self-consistent.
4. Returns the total log likelihood as the figure of merit.

It also manages result output (``ResultPoint``) and on-the-fly visualisation
(``VisualizerPoint``).
"""
import numpy as np

import autofit as af
import autogalaxy as ag

from autogalaxy.analysis.analysis.analysis import Analysis as AgAnalysis

from autolens.analysis.analysis.lens import AnalysisLens
from autolens.point.fit.positions.image.pair_repeat import FitPositionsImagePairRepeat
from autolens.point.fit.dataset import FitPointDataset
from autolens.point.dataset import PointDataset
from autolens.point.model.result import ResultPoint
from autolens.point.model.visualizer import VisualizerPoint
from autolens.point.solver import PointSolver


[docs] class AnalysisPoint(AgAnalysis, AnalysisLens): Visualizer = VisualizerPoint Result = ResultPoint def __init__( self, dataset: PointDataset, solver: PointSolver, fit_positions_cls=FitPositionsImagePairRepeat, image=None, cosmology: ag.cosmo.LensingCosmology = None, title_prefix: str = None, use_jax: bool = True, **kwargs, ): """ Fits a lens model to a point source dataset (e.g. positions, fluxes, time delays) via a non-linear search. The `Analysis` class defines the `log_likelihood_function` which fits the model to the dataset and returns the log likelihood value defining how well the model fitted the data. It handles many other tasks, such as visualization, outputting results to hard-disk and storing results in a format that can be loaded after the model-fit is complete. This class is used for model-fits which fit lens models to point datasets, which may include some combination of positions, fluxes and time-delays. This class stores the settings used to perform the model-fit for certain components of the model (e.g. a pixelization or inversion), the Cosmology used for the analysis and adapt images used for certain model classes. Parameters ---------- dataset The `PointDataset` that is fitted by the model, which contains a combination of positions, fluxes and time-delays. solver Solves the lens equation in order to determine the image-plane positions of a point source by ray-tracing triangles to and from the source-plane. fit_positions_cls The class used to fit the positions of the point source dataset, which could be an image-plane or source-plane chi-squared. cosmology The Cosmology assumed for this analysis. title_prefix A string that is added before the title of all figures output by visualization, for example to put the name of the dataset and galaxy in the title. """ super().__init__(cosmology=cosmology, use_jax=use_jax, **kwargs) AnalysisLens.__init__(self=self, cosmology=cosmology, use_jax=use_jax) self.dataset = dataset self.solver = solver self.fit_positions_cls = fit_positions_cls self.title_prefix = title_prefix
[docs] def log_likelihood_function(self, instance): """ Given an instance of the model, where the model parameters are set via a non-linear search, fit the model instance to the point source dataset. This function returns a log likelihood which is used by the non-linear search to guide the model-fit. For this analysis class, this function performs the following steps: 1) Extracts all galaxies from the model instance and set up a `Tracer`, which includes ordering the galaxies by redshift to set up each `Plane`. 2) Use the `Tracer` and other attributes to create a `FitPointDataset` object, which performs the steps below to fit different components of the point source dataset. 3) If the point source dataset has positions and model fits positions, perform this fit and compute the log likelihood. This calculation uses the `fit_positions_cls` object, which may be an image-plane or source-plane chi-squared. 4) If the point source dataset has fluxes and model fits fluxes, perform this fit and compute the log likelihood. 5) If the point source dataset has time-delays and model fits time-delays, perform this fit and compute the log likelihood [NOT SUPPORTED YET]. 6) Sum the log likelihoods of the positions, fluxes and time-delays (if they are fitted) to get the overall log likelihood of the model. Certain models will fail to fit the dataset and raise an exception. For example for ill defined mass models the `PointSolver` may find no solution. In such circumstances the model is discarded and its likelihood value is passed to the non-linear search in a way that it ignores it (for example, using a value of -1.0e99). Parameters ---------- instance An instance of the model that is being fitted to the data by this analysis (whose parameters have been set via a non-linear search). Returns ------- float The log likelihood indicating how well this model instance fitted the imaging data. """ return self.fit_from(instance=instance).log_likelihood
[docs] def fit_from( self, instance, ) -> FitPointDataset: """ Given a model instance create a `FitPointDataset` object. This function is used in the `log_likelihood_function` to fit the model to the imaging data and compute the log likelihood. Parameters ---------- instance An instance of the model that is being fitted to the data by this analysis (whose parameters have been set via a non-linear search). Returns ------- The fit of the lens model to the point source dataset. """ if self._use_jax: self._register_fit_point_pytrees() tracer = self.tracer_via_instance_from( instance=instance, ) return FitPointDataset( dataset=self.dataset, tracer=tracer, solver=self.solver, fit_positions_cls=self.fit_positions_cls, xp=self._xp, )
@staticmethod def _register_fit_point_pytrees() -> None: """Register every type reachable from a ``FitPointDataset`` return value so ``jax.jit(fit_from)`` can flatten its output. ``dataset`` and ``solver`` are constants per analysis — ride as aux so JAX does not recurse into them. ``fit_positions_cls`` is a class reference (not a value) so must also ride as aux. ``tracer`` is dynamic per fit. """ from autoarray.abstract_ndarray import register_instance_pytree from autolens.lens.tracer import Tracer from autolens.point.fit.positions.image.pair_all import FitPositionsImagePairAll from autolens.point.fit.positions.image.pair_repeat import FitPositionsImagePairRepeat from autolens.point.fit.positions.image.pair import FitPositionsImagePair import autogalaxy as ag register_instance_pytree( FitPointDataset, no_flatten=("dataset", "solver", "fit_positions_cls"), ) register_instance_pytree(Tracer, no_flatten=("cosmology",)) # fit-point-pytree: observed data/noise are per-analysis constants; solver/name/use_jax are non-JAX register_instance_pytree( FitPositionsImagePairAll, no_flatten=("solver", "name", "use_jax", "_data", "_noise_map"), ) # fit-point-pytree register_instance_pytree( FitPositionsImagePairRepeat, no_flatten=("solver", "name", "use_jax", "_data", "_noise_map"), ) # fit-point-pytree register_instance_pytree( FitPositionsImagePair, no_flatten=("solver", "name", "use_jax", "_data", "_noise_map"), ) # fit-point-pytree: ag.ps.Point / PointFlux are handled by # autofit.jax.pytrees.register_model before jit is called; skip here.
[docs] def save_attributes(self, paths: af.DirectoryPaths): """ Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files` folder such that they can be loaded after the analysis using PyAutoFit's database and aggregator tools. For this analysis, it uses the `AnalysisDataset` object's method to output the following: - The dataset's point source dataset as a readable .json file. It is common for these attributes to be loaded by many of the template aggregator functions given in the `aggregator` modules. For example, when using the database tools to perform a fit, the default behaviour is for the dataset, settings and other attributes necessary to perform the fit to be loaded via the pickle files output by this function. Parameters ---------- paths The paths object which manages all paths, e.g. where the non-linear search outputs are stored, visualization, and the pickled objects used by the aggregator output by this function. """ ag.output_to_json( obj=self.dataset, file_path=paths._files_path / "dataset.json", )