1.1.2.4. autojax.tests module

class autojax.tests.AutoTestMeta(name, bases, namespace)

Bases: type

Metaclass that auto-generates pytest-benchmark tests.

class autojax.tests.Data(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3)

Bases: object

Test data.

B: int = 3
property K: int
property M: int
property N: int
N_: int = 30
property N_PRIME: int
property P: int
property S: int
property centre: tuple[float, float]

Get the centre of the native grid.

coefficient: float = 1.0
property curvature_matrix: ndarray[tuple[int, int], float64]
property curvature_reg_matrix: ndarray[tuple[int, int], float64]
property data: ndarray[tuple[int], complex128]
property data_vector: ndarray[tuple[int], float64]
dict() → dict
property dirty_image: ndarray[tuple[int], float64]
property grid_radians_2d: ndarray[tuple[int, int, int], float64]
property grid_radians_slim: ndarray[tuple[int, int], float64]
property mapping_matrix: ndarray[tuple[int, int], float64]
property native_index_for_slim_index: ndarray[tuple[int, int], int64]
property neighbors: ndarray[tuple[int], int64]
property neighbors_grid

Convert a neighbors array to a grid, primarily for visualization.

property neighbors_sizes: ndarray[tuple[int], int64]
property noise_map: ndarray[tuple[int], complex128]
property noise_map_real: ndarray[tuple[int], float64]
property pix_indexes_for_sub_slim_index: ndarray[tuple[int, int], int64]
property pix_pixels: int
property pix_size_for_sub_slim_index: ndarray[tuple[int], int64]
property pix_weights_for_sub_slim_index: ndarray[tuple[int, int], float64]
pixel_scale: float = 0.2
property pixel_scales: tuple[float, float]

Get the pixel scales of the native grid.

property radius: float
property real_space_mask: ndarray[tuple[int, int], bool]
property regularization_matrix: ndarray[tuple[int, int], float64]
property shape_masked_pixels_2d: tuple[int, int]

Get the shape of the masked grid.

property shape_native: tuple[int, int]

Get the shape of the native grid.

property slim_index_for_sub_slim_index: ndarray[tuple[int], int64]
property sub_fraction: ndarray[tuple[int], float64]
property uv_wavelengths: ndarray[tuple[int, int], float64]
property visibilities_real: ndarray[tuple[int], float64]
property w_compact: ndarray[tuple[int, int], float64]
property w_tilde: ndarray[tuple[int, int], float64]
property w_tilde_preload: ndarray[tuple[int, int], float64]
class autojax.tests.DataGenerated(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3, K_: int = 1024, P_: int = 32, S_: int = 256)

Bases: Data

Generate data for testing.

property K: int
K_: int = 1024
property P: int
P_: int = 32
property S: int
S_: int = 256
property data: ndarray[tuple[int], complex128]

Generate a random data map.

property dirty_image: ndarray[tuple[int], float64]

Generate a random dirty image.

property neighbors: ndarray[tuple[int, int], int64]

Generate random neighbors.

property neighbors_sizes: ndarray[tuple[int], int64]
property noise_map: ndarray[tuple[int], complex128]

Generate a random noise map.

property pix_indexes_for_sub_slim_index: ndarray[tuple[int, int], int64]
property pix_weights_for_sub_slim_index: ndarray[tuple[int, int], float64]
property uv_wavelengths: ndarray[tuple[int, int], float64]

Generate random uv wavelengths.

class autojax.tests.DataLoaded(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3, path: Path = PosixPath('/home/runner/work/python-autojax/python-autojax/dataset/data.npz'))

Bases: Data

Load data from file.

property K: int
property M: int
property P: int
property S: int
property data
property dirty_image
property neighbors
property neighbors_sizes
property noise_map
path: Path = PosixPath('/home/runner/work/python-autojax/python-autojax/dataset/data.npz')
property pix_indexes_for_sub_slim_index
property pix_weights_for_sub_slim_index
property uv_wavelengths
class autojax.tests.Reference(data: Data)

Bases: object

Generate reference values for testing.

data: Data
property ref: dict
class autojax.tests.TestCurvatureMatrix

Bases: object

Compute curvature matrix via various methods.

The input w can be w_tilde, preload, or compact. The full w_tilde is allowed because it can be expanded in memory once, outside the MCMC loop.

The input mapping_matrix must be given in its sparse form, such as pix_weights_for_sub_slim_index, …

This is because the mapping matrix has to be generated on the fly anyway: even the dense form must be built from the sparse form at some point inside the MCMC loop.
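To illustrate why the sparse form is sufficient, the NumPy sketch below computes the curvature matrix F = Tᵀ W T in two ways: by first expanding the dense mapping matrix T, and by accumulating directly from the sparse arrays with four nested loops (loosely analogous to the "direct 4-loop matmul" variant above). It assumes row i of pix_indexes and pix_weights lists the B source pixels and weights for image pixel i, that every row is fully populated (no pix_size_for_sub_slim_index padding), and that sub-gridding can be ignored. All helper names and sizes are hypothetical, not autojax functions.

```python
import numpy as np


def dense_mapping_matrix(pix_indexes, pix_weights, S):
    """Expand the sparse (M, B) form into a dense (M, S) mapping matrix."""
    M, B = pix_indexes.shape
    T = np.zeros((M, S))
    rows = np.repeat(np.arange(M), B)
    # np.add.at accumulates correctly when a row repeats a source index
    np.add.at(T, (rows, pix_indexes.ravel()), pix_weights.ravel())
    return T


def curvature_matrix_dense(w_tilde, pix_indexes, pix_weights, S):
    T = dense_mapping_matrix(pix_indexes, pix_weights, S)
    return T.T @ w_tilde @ T


def curvature_matrix_sparse_direct(w_tilde, pix_indexes, pix_weights, S):
    """F[a, b] = sum_ij T[i, a] W[i, j] T[j, b], without ever forming T."""
    M, B = pix_indexes.shape
    F = np.zeros((S, S))
    for i in range(M):
        for j in range(M):
            w = w_tilde[i, j]
            for p in range(B):
                for q in range(B):
                    F[pix_indexes[i, p], pix_indexes[j, q]] += (
                        pix_weights[i, p] * w * pix_weights[j, q]
                    )
    return F


rng = np.random.default_rng(0)
M, S, B = 20, 8, 3
pix_indexes = rng.integers(0, S, size=(M, B))
pix_weights = rng.random((M, B))
A = rng.standard_normal((M, M))
w_tilde = A @ A.T  # symmetric stand-in for a real w_tilde

F_dense = curvature_matrix_dense(w_tilde, pix_indexes, pix_weights, S)
F_sparse = curvature_matrix_sparse_direct(w_tilde, pix_indexes, pix_weights, S)
```

Both routes give the same matrix; the benchmarks are about which route is cheaper for realistic M, S and B.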

test_curvature_matrix_jax(data_bundle, benchmark)

From w_tilde, construct dense mapping matrix.

test_curvature_matrix_jax_BCOO(data_bundle, benchmark)

From w_tilde, construct BCOO mapping matrix.

test_curvature_matrix_jax_compact_sparse(data_bundle, benchmark)

From w_compact, internal sparse mapping matrix.

test_curvature_matrix_jax_compact_sparse_BCOO(data_bundle, benchmark)

From w_compact, left BCOO mapping matrix, right internal sparse mapping matrix.

test_curvature_matrix_jax_sparse(data_bundle, benchmark)

From w_tilde, internal sparse mapping matrix.

test_curvature_matrix_numba(data_bundle, benchmark)

From w_tilde, construct dense mapping matrix.

test_curvature_matrix_numba_compact_sparse(data_bundle, benchmark)

From w_compact, internal sparse mapping matrix, sparse matmul.

test_curvature_matrix_numba_compact_sparse_direct(data_bundle, benchmark)

From w_compact, internal sparse mapping matrix, direct 4-loop matmul.

test_curvature_matrix_numba_sparse(data_bundle, benchmark)

From w_tilde, internal sparse mapping matrix.

test_curvature_matrix_original(data_bundle, benchmark)

From w_tilde, construct dense mapping matrix.

test_curvature_matrix_original_preload_direct(data_bundle, benchmark)

From w_tilde_preload, internal sparse mapping matrix.

class autojax.tests.TestJax

Bases: object

mod = <module 'autojax.jax' from '/home/runner/work/python-autojax/python-autojax/src/autojax/jax.py'>
test_constant_regularization_matrix_from_jax(data_bundle, benchmark)
test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_jax(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_jax(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_from_jax(data_bundle, benchmark)
test_data_vector_from_jax(data_bundle, benchmark)
test_log_likelihood_function_via_w_compact_from_jax(data_bundle, benchmark)
test_log_likelihood_function_via_w_tilde_from_jax(data_bundle, benchmark)
test_mapping_matrix_from_jax(data_bundle, benchmark)
test_mask_2d_circular_from_jax(data_bundle, benchmark)
test_noise_normalization_complex_from_jax(data_bundle, benchmark)
test_reconstruction_positive_negative_from_jax(data_bundle, benchmark)
test_sparse_mapping_matrix_transpose_matmul_jax(data_bundle, benchmark)
test_w_compact_curvature_interferometer_from_jax(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_jax(data_bundle, benchmark)
test_w_tilde_curvature_preload_interferometer_from_jax(data_bundle, benchmark)
test_w_tilde_data_interferometer_from_jax(data_bundle, benchmark)
test_w_tilde_via_compact_from_jax(data_bundle, benchmark)
test_w_tilde_via_preload_from_jax(data_bundle, benchmark)
class autojax.tests.TestNumba

Bases: object

mod = <module 'autojax.numba' from '/home/runner/work/python-autojax/python-autojax/src/autojax/numba.py'>
test_constant_regularization_matrix_from_numba(data_bundle, benchmark)
test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_numba(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_numba(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_from_numba(data_bundle, benchmark)
test_data_vector_from_numba(data_bundle, benchmark)
test_log_likelihood_function_via_w_compact_from_numba(data_bundle, benchmark)
test_log_likelihood_function_via_w_tilde_from_numba(data_bundle, benchmark)
test_mapping_matrix_from_numba(data_bundle, benchmark)
test_mask_2d_circular_from_numba(data_bundle, benchmark)
test_noise_normalization_complex_from_numba(data_bundle, benchmark)
test_reconstruction_positive_negative_from_numba(data_bundle, benchmark)
test_sparse_mapping_matrix_transpose_matmul_numba(data_bundle, benchmark)
test_w_compact_curvature_interferometer_from_numba(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_numba(data_bundle, benchmark)
test_w_tilde_curvature_preload_interferometer_from_numba(data_bundle, benchmark)
test_w_tilde_data_interferometer_from_numba(data_bundle, benchmark)
test_w_tilde_via_compact_from_numba(data_bundle, benchmark)
test_w_tilde_via_preload_from_numba(data_bundle, benchmark)
class autojax.tests.TestOriginal

Bases: object

mod = <module 'autojax.original' from '/home/runner/work/python-autojax/python-autojax/src/autojax/original.py'>
test_constant_regularization_matrix_from_original(data_bundle, benchmark)
test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_original(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_original(data_bundle, benchmark)
test_curvature_matrix_via_w_tilde_from_original(data_bundle, benchmark)
test_data_vector_from_original(data_bundle, benchmark)
test_log_likelihood_function_via_w_compact_from_original(data_bundle, benchmark)
test_log_likelihood_function_via_w_tilde_from_original(data_bundle, benchmark)
test_mapping_matrix_from_original(data_bundle, benchmark)
test_mask_2d_circular_from_original(data_bundle, benchmark)
test_noise_normalization_complex_from_original(data_bundle, benchmark)
test_reconstruction_positive_negative_from_original(data_bundle, benchmark)
test_sparse_mapping_matrix_transpose_matmul_original(data_bundle, benchmark)
test_w_compact_curvature_interferometer_from_original(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_original(data_bundle, benchmark)
test_w_tilde_curvature_preload_interferometer_from_original(data_bundle, benchmark)
test_w_tilde_data_interferometer_from_original(data_bundle, benchmark)
test_w_tilde_via_compact_from_original(data_bundle, benchmark)
test_w_tilde_via_preload_from_original(data_bundle, benchmark)
class autojax.tests.TestWTilde

Bases: object

Compute w_tilde via various methods.

This complements the existing benchmarks by comparing the performance of the preload method.

The test names are a bit strange, but they are designed to be filtered like this:

pytest -k w_tilde_curvature_interferometer_from

This compares:

  1. direct computation of w_tilde

  2. (suffixed with _compact/_preload) computation of w_tilde in the compact/preload form

  3. (suffixed with _expanded) computation of w_tilde as in (2), followed by expansion to the full form

(1) and (3) should be compared if w_tilde is needed in its full form. (1) and (2) should be compared if any form of w_tilde is acceptable.
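To make variants (1) to (3) concrete, the sketch below assumes the usual interferometer definition w_tilde[i, j] = Σ_k cos(2π(Δy·u_k + Δx·v_k)) / σ_k², where (Δy, Δx) is the offset between real-space pixels i and j. On a regular, unmasked grid each entry depends only on the integer pixel offset, so a "compact" table indexed by offset holds every distinct value, and expansion is a lookup. The table layout, the (y, x) to (u, v) pairing and the function names are assumptions for illustration; the actual w_compact and w_tilde_preload layouts in autojax may differ.

```python
import numpy as np


def w_tilde_direct(grid, uv, noise):
    """Variant (1): full (N, N) w_tilde from the assumed definition."""
    dy = grid[:, None, 0] - grid[None, :, 0]
    dx = grid[:, None, 1] - grid[None, :, 1]
    phase = 2 * np.pi * (dy[..., None] * uv[:, 0] + dx[..., None] * uv[:, 1])
    return np.sum(np.cos(phase) / noise**2, axis=-1)


def w_compact_sketch(shape, pixel_scale, uv, noise):
    """Variant (2): one value per integer pixel offset, shape (2Ny-1, 2Nx-1)."""
    ny, nx = shape
    oy = np.arange(-(ny - 1), ny) * pixel_scale
    ox = np.arange(-(nx - 1), nx) * pixel_scale
    phase = 2 * np.pi * (oy[:, None, None] * uv[:, 0] + ox[None, :, None] * uv[:, 1])
    return np.sum(np.cos(phase) / noise**2, axis=-1)


def expand(compact, shape):
    """Variant (3): rebuild the full (N, N) matrix by looking up pixel offsets."""
    ny, nx = shape
    iy, ix = np.divmod(np.arange(ny * nx), nx)
    dy = iy[:, None] - iy[None, :] + (ny - 1)
    dx = ix[:, None] - ix[None, :] + (nx - 1)
    return compact[dy, dx]


rng = np.random.default_rng(1)
shape, pixel_scale = (4, 5), 0.2
uv = rng.normal(size=(16, 2))
noise = rng.uniform(0.5, 1.5, size=16)
iy, ix = np.divmod(np.arange(shape[0] * shape[1]), shape[1])
grid = np.stack([iy, ix], axis=1) * pixel_scale  # row-major pixel centres

full = w_tilde_direct(grid, uv, noise)
compact = w_compact_sketch(shape, pixel_scale, uv, noise)
expanded = expand(compact, shape)
```

The compact table has (2Ny-1)(2Nx-1) entries versus (Ny·Nx)² for the full matrix. Since cosine is even, half of the table is also redundant, which a real implementation could exploit.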

test_w_tilde_curvature_interferometer_from_jax_compact(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_jax_compact_expanded(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_numba_compact(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_numba_compact_expanded(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_original_preload(data_bundle, benchmark)
test_w_tilde_curvature_interferometer_from_original_preload_expanded(data_bundle, benchmark)
autojax.tests.data_bundle(request)
autojax.tests.deterministic_seed(string: str, *numbers: int) → int

Generate a deterministic seed from a string (such as the class name) and optional integers.
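Python's built-in hash() is randomized per process for strings, so a reproducible seed needs a stable hash. A minimal sketch, assuming a SHA-256 digest of the string and numbers truncated to 32 bits; deterministic_seed_sketch is a hypothetical name and the actual autojax implementation may combine its inputs differently.

```python
import hashlib


def deterministic_seed_sketch(string: str, *numbers: int) -> int:
    """Stable seed from a string plus integers (hypothetical re-implementation)."""
    # The separator keeps ("a1",) and ("a", 1) from producing the same key
    key = "|".join([string, *map(str, numbers)])
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    # 32 bits also satisfies legacy np.random.seed, which requires values below 2**32
    return int.from_bytes(digest[:4], "little")


seed = deterministic_seed_sketch("TestJax", 30, 3)
```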

autojax.tests.gen_neighbors(S, P, rng) → ndarray[tuple[int, int], int64]

Generate random neighbors.
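The neighbor data appear to use a padded layout: one row per source pixel with up to P neighbor indices, and the per-row count kept separately (neighbors_sizes). A sketch of how such random data could be generated, assuming -1 as the padding value and neighbors drawn without replacement while excluding the pixel itself; gen_neighbors_sketch is hypothetical, and the real gen_neighbors may make other choices (for example, enforcing a symmetric relation).

```python
import numpy as np


def gen_neighbors_sketch(S: int, P: int, rng: np.random.Generator) -> np.ndarray:
    """Random padded neighbor table of shape (S, P); unused slots hold -1."""
    if P > S - 1:
        raise ValueError("need P <= S - 1 to draw distinct neighbors")
    neighbors = np.full((S, P), -1, dtype=np.int64)
    sizes = rng.integers(1, P + 1, size=S)  # between 1 and P neighbors per pixel
    for s in range(S):
        others = np.delete(np.arange(S), s)  # a pixel is not its own neighbor
        chosen = rng.choice(others, size=sizes[s], replace=False)
        neighbors[s, : sizes[s]] = chosen
    return neighbors


neighbors = gen_neighbors_sketch(S=10, P=4, rng=np.random.default_rng(0))
neighbors_sizes = (neighbors >= 0).sum(axis=1)  # counts implied by the padding
```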

autojax.tests.gen_pix_indexes_for_sub_slim_index(M: int, S: int, B: int) → ndarray[tuple[int, int], int64]
autojax.tests.get_run(func, data_dict, jax=False)
autojax.tests.get_run_composed_from(func1, func2, data_dict, jax=False)
autojax.tests.get_run_composed_from_prepend(func1, func2, data_dict, jax=False)
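These helpers carry no docstrings. Judging only from their signatures and from Data.dict(), one plausible reading is that get_run binds the parameters of func to same-named entries of data_dict and returns a zero-argument callable to hand to pytest-benchmark, optionally compiled with jax.jit. The sketch below encodes that assumption; get_run_sketch and dirty_sum are hypothetical, and the composed variants are not covered.

```python
import inspect


def get_run_sketch(func, data_dict, jax=False):
    """Return a zero-argument callable running func on arguments from data_dict."""
    names = list(inspect.signature(func).parameters)
    args = [data_dict[name] for name in names]  # assumption: match by parameter name
    if jax:
        import jax as jax_lib  # imported lazily so NumPy-only runs need no JAX

        func = jax_lib.jit(func)
    return lambda: func(*args)


def dirty_sum(dirty_image, coefficient):
    return coefficient * sum(dirty_image)


run = get_run_sketch(dirty_sum, {"dirty_image": [1.0, 2.0], "coefficient": 2.0, "N_": 30})
```

In a test, the resulting callable would then be passed to the fixture as benchmark(run).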
autojax.tests.neighbors_grid(neighbors: ndarray[tuple[int, int], int64]) → ndarray[tuple[int, int], bool]
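From its signature, neighbors_grid turns an integer (S, P) neighbor table into a boolean array, which suits display as an adjacency image (for example with matplotlib's imshow). A sketch, assuming -1 marks unused slots and that entry [a, b] is True when b is listed as a neighbor of a; neighbors_grid_sketch is a hypothetical name.

```python
import numpy as np


def neighbors_grid_sketch(neighbors: np.ndarray) -> np.ndarray:
    """Boolean (S, S) grid: grid[a, b] is True if b is listed as a neighbor of a."""
    S = neighbors.shape[0]
    grid = np.zeros((S, S), dtype=bool)
    rows, slots = np.nonzero(neighbors >= 0)  # skip the -1 padding
    grid[rows, neighbors[rows, slots]] = True
    return grid


neighbors = np.array([[1, 2, -1], [0, -1, -1], [0, 1, -1]])
grid = neighbors_grid_sketch(neighbors)
```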