Source code for autoarray.operators.transformer

from astropy import units
import copy
import numpy as np
import warnings


class NUFFTPlaceholder:
    """
    Stand-in base class bound to ``NUFFT_cpu`` when ``pynufft`` cannot be imported.

    ``TransformerNUFFT`` checks ``isinstance(self, NUFFTPlaceholder)`` to detect the
    missing dependency and raise an informative error.
    """


class PyLopsPlaceholder:
    """
    Stand-in base class bound to ``PyLopsOperator`` when ``pylops`` cannot be imported.

    The transformer classes check ``isinstance(self, PyLopsPlaceholder)`` to detect the
    missing dependency and raise an informative error.
    """


try:
    from pynufft.linalg.nufft_cpu import NUFFT_cpu
except ModuleNotFoundError:
    NUFFT_cpu = NUFFTPlaceholder

try:
    import pylops

    PyLopsOperator = pylops.LinearOperator
except ModuleNotFoundError:
    PyLopsOperator = PyLopsPlaceholder

from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.grids.uniform_2d import Grid2D
from autoarray.structures.visibilities import Visibilities

from autoarray.structures.arrays import array_2d_util
from autoarray.operators import transformer_util


def pynufft_exception():
    """
    Raise an error telling the user that the optional PyNUFFT library is missing.

    Raises
    ------
    ModuleNotFoundError
        Always. The message includes the project URL and a pinned ``pip install`` command.
    """
    message_lines = (
        "\n--------------------\n",
        "You are attempting to perform interferometer analysis.\n\n",
        "However, the optional library PyNUFFT (https://github.com/jyhmiinlin/pynufft) is not installed.\n\n",
        "Install it via the command `pip install pynufft==2022.2.2`.\n\n",
        "----------------------",
    )
    raise ModuleNotFoundError("".join(message_lines))


def pylops_exception():
    """
    Raise an error telling the user that the optional PyLops library is missing.

    Raises
    ------
    ModuleNotFoundError
        Always. The message includes the project URL and a pinned ``pip install`` command.
    """
    message_lines = (
        "\n--------------------\n",
        "You are attempting to perform interferometer analysis.\n\n",
        "However, the optional library PyLops (https://github.com/PyLops/pylops) is not installed.\n\n",
        "Install it via the command `pip install pylops==1.18.3`.\n\n",
        "----------------------",
    )
    raise ModuleNotFoundError("".join(message_lines))


class TransformerDFT(PyLopsOperator):
    """
    Direct Fourier transform between a masked real-space image and a set of
    interferometer visibilities.

    The class subclasses ``PyLopsOperator`` (``pylops.LinearOperator`` when PyLops is
    installed) and raises an informative error if PyLops is missing.
    """

    def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True):
        """
        Parameters
        ----------
        uv_wavelengths
            Baseline array. It is stored as ``uv_wavelengths.astype("float")``, and its
            ``shape[0]`` is taken as the number of visibilities.
        real_space_mask
            Mask of the real-space image. Its ``pixels_in_mask`` sets the operator's
            column count, and the radian grid used by the transforms is derived from it.
        preload_transform
            If True, the real and imaginary transform matrices are computed once here
            and reused by ``visibilities_from`` and ``transform_mapping_matrix``.
        """
        if isinstance(self, PyLopsPlaceholder):
            pylops_exception()

        super().__init__()

        mask = real_space_mask

        self.uv_wavelengths = uv_wavelengths.astype("float")
        self.real_space_mask = mask
        self.grid = mask.derive_grid.unmasked.in_radians

        self.total_visibilities = uv_wavelengths.shape[0]
        self.total_image_pixels = mask.pixels_in_mask

        self.preload_transform = preload_transform

        if preload_transform:
            # Each preload receives its own numpy copy of the radian grid.
            self.preload_real_transforms = transformer_util.preload_real_transforms(
                grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths
            )
            self.preload_imag_transforms = transformer_util.preload_imag_transforms(
                grid_radians=np.array(self.grid), uv_wavelengths=self.uv_wavelengths
            )

        self.real_space_pixels = mask.pixels_in_mask

        # Operator shape: (number of visibilities, pixels_in_mask of the mask).
        self.shape = (int(self.total_visibilities), int(self.real_space_pixels))
        self.dtype = "complex128"
        self.explicit = False

        # NOTE: This is the scaling factor that needs to be applied to the adjoint operator
        shape_native = self.grid.shape_native
        self.adjoint_scaling = (2.0 * shape_native[0]) * (2.0 * shape_native[1])

        self.matvec_count = self.rmatvec_count = self.matmat_count = self.rmatmat_count = 0

    def visibilities_from(self, image):
        """
        Transform a real-space image into ``Visibilities``.

        With preloaded transforms, ``np.array(image)`` is passed to the preloaded kernel.
        Otherwise ``np.array(image.slim)`` is passed to the on-the-fly DFT over
        ``self.grid``.

        NOTE(review): the two branches convert ``image`` differently; confirm that both
        yield the same 1D array for the image types callers pass.
        """
        if not self.preload_transform:
            return Visibilities(
                visibilities=transformer_util.visibilities_jit(
                    image_1d=np.array(image.slim),
                    grid_radians=np.array(self.grid),
                    uv_wavelengths=self.uv_wavelengths,
                )
            )

        return Visibilities(
            visibilities=transformer_util.visibilities_via_preload_jit_from(
                image_1d=np.array(image),
                preloaded_reals=self.preload_real_transforms,
                preloaded_imags=self.preload_imag_transforms,
            )
        )

    def image_from(self, visibilities, use_adjoint_scaling: bool = False):
        """
        Transform visibilities back into an ``Array2D`` image over ``real_space_mask``.

        ``use_adjoint_scaling`` is accepted but currently not used by this implementation.

        NOTE(review): ``TransformerNUFFT.image_from`` multiplies by ``adjoint_scaling``
        when this flag is True. Confirm whether the DFT path should do the same.
        """
        slim_image = transformer_util.image_via_jit_from(
            n_pixels=self.grid.shape[0],
            grid_radians=np.array(self.grid),
            uv_wavelengths=self.uv_wavelengths,
            visibilities=visibilities.in_array,
        )

        return Array2D(
            values=array_2d_util.array_2d_native_from(
                array_2d_slim=slim_image, mask_2d=self.real_space_mask
            ),
            mask=self.real_space_mask,
        )

    def transform_mapping_matrix(self, mapping_matrix):
        """
        Fourier transform every column of ``mapping_matrix`` into visibility space.

        The preloaded transform matrices are used when available; otherwise the DFT is
        evaluated directly from ``self.grid`` and ``self.uv_wavelengths``.
        """
        if not self.preload_transform:
            return transformer_util.transformed_mapping_matrix_jit(
                mapping_matrix=mapping_matrix,
                grid_radians=np.array(self.grid),
                uv_wavelengths=self.uv_wavelengths,
            )

        return transformer_util.transformed_mapping_matrix_via_preload_jit_from(
            mapping_matrix=mapping_matrix,
            preloaded_reals=self.preload_real_transforms,
            preloaded_imags=self.preload_imag_transforms,
        )
class TransformerNUFFT(NUFFT_cpu, PyLopsOperator):
    """
    Non-uniform FFT (PyNUFFT ``NUFFT_cpu``) between a masked real-space image and
    interferometer visibilities. It can also act as a PyLops linear operator through
    ``_matvec`` / ``_rmatvec``.

    Warnings emitted during a transform are suppressed only for that call, via
    ``warnings.catch_warnings()``. The process-wide warning filters are left untouched.
    """

    def __init__(self, uv_wavelengths, real_space_mask):
        """
        Parameters
        ----------
        uv_wavelengths
            Baseline array whose columns 0 and 1 are read as the two baseline
            coordinates. The operator's row count is
            ``uv_wavelengths.shape[0] * uv_wavelengths.shape[1]``.
        real_space_mask
            Mask of the real-space image. The NUFFT plan is built on the radian grid
            derived from it, and ``pixels_in_mask`` sets the operator's column count.
        """
        if isinstance(self, NUFFTPlaceholder):
            pynufft_exception()

        if isinstance(self, PyLopsPlaceholder):
            pylops_exception()

        super().__init__()

        self.uv_wavelengths = uv_wavelengths
        self.real_space_mask = real_space_mask
        self.grid = Grid2D.from_mask(mask=self.real_space_mask).in_radians
        self.native_index_for_slim_index = copy.copy(
            real_space_mask.derive_indexes.native_for_slim.astype("int")
        )

        # NOTE: The plan need only be initialized once
        self.initialize_plan()

        # Phase factor exp(-2*pi*i * (d * v1 + d * v0)), where d is half of
        # pixel_scales[0] converted from arcsec to radians. It is stored on the
        # instance but not read by any method of this class.
        # NOTE(review): both terms use pixel_scales[0]; presumably square pixels are
        # assumed. Confirm.
        half_pixel_rad = self.grid.pixel_scales[0] / 2.0 * units.arcsec.to(units.rad)
        self.shift = np.exp(
            -2.0
            * np.pi
            * 1j
            * (
                half_pixel_rad * self.uv_wavelengths[:, 1]
                + half_pixel_rad * self.uv_wavelengths[:, 0]
            )
        )

        self.real_space_pixels = self.real_space_mask.pixels_in_mask

        # forward_lop returns the real and imaginary parts stacked, so the operator has
        # uv_wavelengths.shape[0] * uv_wavelengths.shape[1] rows: 2 x Nvis for a
        # two-column uv_wavelengths. An un-reshaped operator would be (Nvis, Np).
        self.total_visibilities = int(uv_wavelengths.shape[0] * uv_wavelengths.shape[1])

        self.shape = (int(self.total_visibilities), int(self.real_space_pixels))

        # NOTE: If the operator is reshaped then the output is real.
        self.dtype = "float64"
        self.explicit = False

        # NOTE: This is the scaling factor that needs to be applied to the adjoint operator
        self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * (
            2.0 * self.grid.shape_native[1]
        )

        self.matvec_count = 0
        self.rmatvec_count = 0
        self.matmat_count = 0
        self.rmatmat_count = 0

    def initialize_plan(self, ratio=2, interp_kernel=(6, 6)):
        """
        Build the PyNUFFT plan for this grid and set of baselines.

        Parameters
        ----------
        ratio
            Oversampling factor of the k-space grid (``Kd``) relative to
            ``grid.shape_native``. It is cast to ``int``.
        interp_kernel
            Interpolation kernel size, passed to ``plan`` as ``Jd``.
        """
        ratio = int(ratio)

        # Normalized baselines passed as PyNUFFT's ``om``. Column 1 of uv_wavelengths
        # goes first, then column 0.
        scale = 1.0 / (2.0 * self.grid.pixel_scales[0] * units.arcsec.to(units.rad))
        visibilities_normalized = np.array(
            [
                self.uv_wavelengths[:, 1] / scale * np.pi,
                self.uv_wavelengths[:, 0] / scale * np.pi,
            ]
        ).T

        shape_native = self.grid.shape_native
        self.plan(
            om=visibilities_normalized,
            Nd=shape_native,
            Kd=(ratio * shape_native[0], ratio * shape_native[1]),
            Jd=interp_kernel,
        )

    def visibilities_from(self, image):
        """
        Forward NUFFT of ``image`` into a ``Visibilities`` object.

        The native image is flipped along its first axis before the transform to
        account for PyNUFFT's internal flip. Warnings are suppressed only for the
        duration of this call.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return Visibilities(
                visibilities=self.forward(
                    image.native[::-1, :]
                )  # flip due to PyNUFFT internal flip
            )

    def image_from(self, visibilities, use_adjoint_scaling: bool = False):
        """
        Adjoint NUFFT of ``visibilities`` into an ``Array2D`` over ``real_space_mask``.

        The real part of the adjoint is flipped along its first axis. When
        ``use_adjoint_scaling`` is True, it is then multiplied by
        ``self.adjoint_scaling``.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            image = np.real(self.adjoint(visibilities))[::-1, :]

        if use_adjoint_scaling:
            image *= self.adjoint_scaling

        return Array2D(values=image, mask=self.real_space_mask)

    def transform_mapping_matrix(self, mapping_matrix):
        """
        Forward-transform every column of ``mapping_matrix`` into visibilities.

        Each column is treated as a slim image over ``self.grid.mask``.

        Returns
        -------
        np.ndarray
            Complex array of shape ``(uv_wavelengths.shape[0], mapping_matrix.shape[1])``.
        """
        n_columns = mapping_matrix.shape[1]
        transformed_mapping_matrix = np.zeros(
            (self.uv_wavelengths.shape[0], n_columns), dtype="complex128"
        )

        for column in range(n_columns):
            image_2d = array_2d_util.array_2d_native_from(
                array_2d_slim=mapping_matrix[:, column],
                mask_2d=self.grid.mask,
            )
            image = Array2D(values=image_2d, mask=self.grid.mask)
            transformed_mapping_matrix[:, column] = self.visibilities_from(image=image)

        return transformed_mapping_matrix

    def forward_lop(self, x):
        """
        Forward NUFFT used as the PyLops ``_matvec``.

        Parameters
        ----------
        x
            Slim (1D, masked) image vector. It is scattered onto the native grid via
            ``native_index_for_slim_index`` and flipped along the first axis before the
            transform.

        Returns
        -------
        np.ndarray
            The real parts of the visibilities followed by their imaginary parts,
            concatenated along axis 0.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            x2d = array_2d_util.array_2d_native_complex_via_indexes_from(
                array_2d_slim=x,
                shape_native=self.real_space_mask.shape_native,
                native_index_for_slim_index_2d=self.native_index_for_slim_index,
            )[::-1, :]

            y = self.k2y(self.xx2k(self.x2xx(x2d)))

        return np.concatenate((y.real, y.imag), axis=0)

    def adjoint_lop(self, y):
        """
        Adjoint NUFFT used as the PyLops ``_rmatvec``.

        Parameters
        ----------
        y
            Real vector whose first ``self.shape[0] // 2`` entries are the real parts and
            whose remaining entries are the imaginary parts of the visibilities. This is
            the layout produced by ``forward_lop``.

        Returns
        -------
        np.ndarray
            Real slim image vector over ``real_space_mask``, multiplied by
            ``self.adjoint_scaling``.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            half = self.shape[0] // 2
            y_complex = y[:half] + 1j * y[half:]

            x2d = np.real(self.xx2x(self.k2xx(self.y2k(y_complex))))

            # Undo the first-axis flip applied in forward_lop before slimming.
            x = array_2d_util.array_2d_slim_complex_from(
                array_2d_native=x2d[::-1, :],
                mask=np.array(self.real_space_mask),
            )

        x = x.real
        x *= self.adjoint_scaling

        return x

    def _matvec(self, x):
        return self.forward_lop(x)

    def _rmatvec(self, x):
        return self.adjoint_lop(x)