from __future__ import annotations
import numpy as np
from typing import Optional
from autoarray.structures.grids.uniform_2d import Grid2D
from autoarray.inversion.pixelization.image_mesh.abstract_weighted import (
AbstractImageMeshWeighted,
)
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
from autoarray.inversion.pixelization.image_mesh.hilbert import image_and_grid_from
from autoarray.inversion.pixelization.image_mesh.hilbert import (
inverse_transform_sampling_interpolated,
)
from autoarray import exc
class HilbertBalanced(AbstractImageMeshWeighted):
    def __init__(
        self,
        pixels=10.0,
        weight_floor=0.0,
        weight_power=0.0,
        ratio=0.8,
    ):
        """
        Computes a balanced image-mesh by computing the Hilbert curve of the adapt data and drawing points from it.

        The standard `Hilbert` image-mesh suffers from a systematic effect where the vast majority of points are
        drawn from the highly weighted regions. This often leaves few points to reconstruct the lower weight
        regions, leading to discontinuities in the reconstruction.

        This image-mesh addresses this by drawing a fraction (`ratio`) of the points from the weight map and the
        remaining points from (1 - weight map). This ensures both high and low weighted regions are sampled, but
        still has sufficient flexibility to dedicate many points to the highest weighted regions.

        This requires an adapt-image, which is the image that the Hilbert curve algorithm adapts to in order to
        compute the image mesh. This could simply be the image itself, or a model fit to the image which removes
        certain features or noise.

        For example, using the adapt image, the image mesh is computed as follows:

        1) Convert the adapt image to a weight map, which is a 2D array of weight values.

        2) Run the Hilbert algorithm on the weight map, such that the image mesh pixels cluster around the weight
           map values with higher values.

        3) Run the Hilbert algorithm again on (1 - weight map), such that the remaining image mesh pixels are
           drawn preferentially from the lower weight regions.

        Parameters
        ----------
        pixels
            The total number of pixels in the image mesh and drawn from the Hilbert curve.
        weight_floor
            The minimum weight value in the weight map, which allows more pixels to be drawn from the lower weight
            regions of the adapt image.
        weight_power
            The power the weight values are raised to, which allows more pixels to be drawn from the higher weight
            regions of the adapt image.
        ratio
            The fraction of the pixels in the image mesh which are drawn from the weight map, with the remainder
            drawn from (1 - weight map). For example, if there are 1000 pixels and a ratio of 0.8, 800 pixels are
            drawn from the weight map and therefore make up the high weighted region of the image mesh.
        """
        super().__init__(
            pixels=pixels,
            weight_floor=weight_floor,
            weight_power=weight_power,
        )

        self.ratio = ratio

    def image_plane_mesh_grid_from(
        self, grid: Grid2D, adapt_data: Optional[np.ndarray], settings=None
    ) -> Grid2DIrregular:
        """
        Returns an image mesh by running the balanced Hilbert curve on the weight map.

        See the `__init__` docstring for a full description of how this is performed.

        Parameters
        ----------
        grid
            The grid of (y,x) coordinates of the image data the pixelization fits, which the Hilbert curve adapts
            to. Its mask must be circular.
        adapt_data
            The weights defining the regions of the image the Hilbert curve adapts to.
        settings
            Not used by this method.

        Returns
        -------
        Grid2DIrregular
            The (y,x) coordinates of the image-mesh pixels: first the points drawn from the weight map, followed
            by the points drawn from (1 - weight map).

        Raises
        ------
        exc.PixelizationException
            If the mask of the input grid is not circular.
        """
        if not grid.mask.is_circular:
            raise exc.PixelizationException(
                "\n"
                "Hilbert image-mesh has been called but the input grid does not use a circular mask.\n"
                "Ensure that analysis is using a circular mask via the Mask2D.circular classmethod.\n"
            )

        adapt_data_hb, grid_hb = image_and_grid_from(
            image=adapt_data,
            mask=grid.mask,
            mask_radius=grid.mask.circular_radius,
            pixel_scales=grid.mask.pixel_scales,
            hilbert_length=193,
        )

        weight_map = self.weight_map_from(adapt_data=adapt_data_hb)

        # The complement must be taken before the weight map is normalised in-place below.
        weight_map_background = 1.0 - weight_map

        # Normalise both maps so that each sums to one before they are passed as `probabilities`.
        weight_map /= np.sum(weight_map)
        weight_map_background /= np.sum(weight_map_background)

        pixels_weighted = int(self.pixels * self.ratio)

        # `self.pixels` may be a float (the constructor default is 10.0), so it is cast explicitly in order for
        # the background sample count to be an integer, like the weighted count above. One extra sample is
        # requested because the first background sample is discarded when the two sets of points are combined.
        pixels_background = int(self.pixels) - pixels_weighted + 1

        # Column 1 of `grid_hb` is passed as the x values and column 0 as the y values.
        grid_hb_x = grid_hb[:, 1]
        grid_hb_y = grid_hb[:, 0]

        _, drawn_x, drawn_y = inverse_transform_sampling_interpolated(
            probabilities=weight_map,
            n_samples=pixels_weighted,
            gridx=grid_hb_x,
            gridy=grid_hb_y,
        )

        mesh_grid_weighted = np.stack((drawn_y, drawn_x), axis=-1)

        _, drawn_x, drawn_y = inverse_transform_sampling_interpolated(
            probabilities=weight_map_background,
            n_samples=pixels_background,
            gridx=grid_hb_x,
            gridy=grid_hb_y,
        )

        mesh_grid_background = np.stack((drawn_y, drawn_x), axis=-1)

        # NOTE(review): the first background point is dropped, presumably because the first sample of each draw
        # lands on the same start of the Hilbert curve and would duplicate a weighted point - confirm against
        # `inverse_transform_sampling_interpolated`.
        return Grid2DIrregular(
            values=np.concatenate(
                (mesh_grid_weighted, mesh_grid_background[1:, :]), axis=0
            )
        )