Source code for autolens.imaging.model.analysis
"""
Analysis class for fitting a ``Tracer`` lens model to an imaging dataset.
``AnalysisImaging`` implements the ``log_likelihood_function`` that a ``PyAutoFit``
non-linear search calls on each iteration. It:
1. Constructs a ``Tracer`` from the current model instance.
2. Optionally applies adaptive galaxy images to linear components.
3. Calls ``FitImaging`` to evaluate the log likelihood.
4. Returns the figure of merit (log likelihood or log evidence).
It also manages result output (``ResultImaging``), on-the-fly visualisation
(``VisualizerImaging``), and position-based priors via ``PositionLikelihood``.
"""
import logging
import autofit as af
import autogalaxy as ag
from autolens.analysis.analysis.dataset import AnalysisDataset
from autolens.imaging.model.result import ResultImaging
from autolens.imaging.model.visualizer import VisualizerImaging
from autolens.imaging.fit_imaging import FitImaging
logger = logging.getLogger(__name__)
logger.setLevel(level="INFO")
[docs]
class AnalysisImaging(AnalysisDataset):
Result = ResultImaging
Visualizer = VisualizerImaging
[docs]
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
    """
    Fit a model instance to the imaging dataset and return its figure of merit.

    This is the function a non-linear search calls to guide the model-fit. For this analysis class it
    performs the following steps:

    1) Compute a log likelihood penalty for the instance via ``log_likelihood_penalty_from``.
    2) Create a ``FitImaging`` object for the instance via ``fit_from``.
    3) Return the fit's ``figure_of_merit`` with the penalty subtracted.

    The same quantity (``figure_of_merit``) is used whether or not JAX is enabled, so both code paths
    rank a given model instance identically.

    Certain models will fail to fit the dataset and raise an exception. For example, if an ``Inversion``
    is used, the linear algebra calculation may be invalid. On the non-JAX path any such exception is
    converted to an ``af.exc.FitException``, which signals to the non-linear search that the model
    should be discarded.

    Parameters
    ----------
    instance
        An instance of the model that is being fitted to the data by this analysis (whose parameters
        have been set via a non-linear search).

    Returns
    -------
    float
        The figure of merit of the fit to the imaging data, minus the log likelihood penalty.

    Raises
    ------
    af.exc.FitException
        If creating or evaluating the fit raises any exception on the non-JAX path. The original
        exception is chained as the cause so the underlying failure remains visible in tracebacks.
    """
    # The penalty is computed outside the ``try`` below, so errors raised while computing it
    # propagate unchanged on both paths.
    log_likelihood_penalty = self.log_likelihood_penalty_from(
        instance=instance,
    )

    if self._use_jax:
        # NOTE(review): the JAX path deliberately has no try / except; presumably exceptions
        # cannot be handled usefully inside traced code -- confirm.
        return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty

    try:
        # Use ``figure_of_merit`` (not ``log_likelihood``) so this path agrees with the JAX path.
        return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty
    except Exception as e:
        raise af.exc.FitException from e
[docs]
def fit_from(
    self,
    instance: af.ModelInstance,
) -> FitImaging:
    """
    Build the `FitImaging` object corresponding to a model instance.

    The `log_likelihood_function` calls this method to fit the model to the imaging data; the
    returned fit exposes the quantities (e.g. the figure of merit) that the search uses.

    The following objects are derived from the instance, in this order, and passed to the fit
    together with the analysis's dataset and settings:

    1) A tracer, via `tracer_via_instance_from`.
    2) A dataset model, via `dataset_model_via_instance_from`.
    3) The adapt images, via `adapt_images_via_instance_from`, using the tracer's galaxies.

    Parameters
    ----------
    instance
        An instance of the model that is being fitted to the data by this analysis (whose parameters
        have been set via a non-linear search).

    Returns
    -------
    FitImaging
        The fit of the tracer to the imaging dataset.
    """
    # When JAX is enabled the types reachable from a ``FitImaging`` must be registered as
    # pytrees before the fit is built.
    if self._use_jax:
        self._register_fit_imaging_pytrees()

    lens_tracer = self.tracer_via_instance_from(instance=instance)
    model_of_dataset = self.dataset_model_via_instance_from(instance=instance)
    galaxy_adapt_images = self.adapt_images_via_instance_from(
        instance=instance,
        galaxies=lens_tracer.galaxies,
    )

    # ``xp`` forwards the array backend held on ``self._xp``
    # (presumably numpy or jax.numpy -- confirm against the base class).
    fit_kwargs = dict(
        dataset=self.dataset,
        tracer=lens_tracer,
        dataset_model=model_of_dataset,
        adapt_images=galaxy_adapt_images,
        settings=self.settings,
        xp=self._xp,
    )
    return FitImaging(**fit_kwargs)
@staticmethod
def _register_fit_imaging_pytrees() -> None:
    """Register the types reachable from a ``FitImaging`` as JAX pytrees.

    This allows ``jax.jit(fit_from)`` to flatten the ``FitImaging`` it returns.

    Attributes listed under ``no_flatten`` are constant for a given analysis (``dataset``,
    ``adapt_images``, ``settings``) or a given fit (``cosmology``), so they ride along as
    auxiliary data and JAX does not recurse into them. Everything else (``tracer``,
    ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit.
    """
    # Imported locally, as in the original implementation.
    from autoarray.abstract_ndarray import register_instance_pytree
    from autoarray.dataset.dataset_model import DatasetModel
    from autolens.lens.tracer import Tracer

    # (class, keyword options) pairs, registered in this order.
    registrations = (
        (FitImaging, {"no_flatten": ("dataset", "adapt_images", "settings")}),
        (DatasetModel, {}),
        # ``cosmology`` is a fixed physical constant per fit; ride as aux.
        (Tracer, {"no_flatten": ("cosmology",)}),
    )
    for pytree_cls, options in registrations:
        register_instance_pytree(pytree_cls, **options)