1.1.2.4. autojax.tests module
- class autojax.tests.AutoTestMeta(name, bases, namespace)
Bases: type
Metaclass to auto-generate pytest-benchmark tests.
- class autojax.tests.Data(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3)
Bases: object
Test data.
- B: int = 3
- property K: int
- property M: int
- property N: int
- N_: int = 30
- property N_PRIME: int
- property P: int
- property S: int
- property centre: tuple[float, float]
Get the centre of the native grid.
- coefficient: float = 1.0
- property curvature_matrix: ndarray[tuple[int, int], float64]
- property curvature_reg_matrix: ndarray[tuple[int, int], float64]
- property data: ndarray[tuple[int], complex128]
- property data_vector: ndarray[tuple[int], float64]
- property dirty_image: ndarray[tuple[int], float64]
- property grid_radians_2d: ndarray[tuple[int, int, int], float64]
- property grid_radians_slim: ndarray[tuple[int, int], float64]
- property mapping_matrix: ndarray[tuple[int, int], float64]
- property native_index_for_slim_index: ndarray[tuple[int, int], int64]
- property neighbors: ndarray[tuple[int], int64]
- property neighbors_grid
Convert a neighbors array to a grid, primarily for visualization.
- property neighbors_sizes: ndarray[tuple[int], int64]
- property noise_map: ndarray[tuple[int], complex128]
- property noise_map_real: ndarray[tuple[int], float64]
- property pix_indexes_for_sub_slim_index: ndarray[tuple[int, int], int64]
- property pix_pixels: int
- property pix_size_for_sub_slim_index: ndarray[tuple[int], int64]
- property pix_weights_for_sub_slim_index: ndarray[tuple[int, int], float64]
- pixel_scale: float = 0.2
- property pixel_scales: tuple[float, float]
Get the pixel scales of the native grid.
- property radius: float
- property real_space_mask: ndarray[tuple[int, int], bool]
- property regularization_matrix: ndarray[tuple[int, int], float64]
- property shape_masked_pixels_2d: tuple[int, int]
Get the shape of the masked grid.
- property shape_native: tuple[int, int]
Get the shape of the native grid.
- property slim_index_for_sub_slim_index: ndarray[tuple[int], int64]
- property sub_fraction: ndarray[tuple[int], float64]
- property uv_wavelengths: ndarray[tuple[int, int], float64]
- property visibilities_real: ndarray[tuple[int], float64]
- property w_compact: ndarray[tuple[int, int], float64]
- property w_tilde: ndarray[tuple[int, int], float64]
- property w_tilde_preload: ndarray[tuple[int, int], float64]
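Data exposes shape_native, pixel_scales, centre, radius and real_space_mask, and the generated classes benchmark tests named test_mask_2d_circular_from_*. As an illustration of how those geometric properties fit together, here is a minimal sketch of a circular mask on the native grid. The function name mask_2d_circular_sketch, the pixel-coordinate convention (y decreasing with row index) and "True means masked out" are assumptions, not necessarily autojax's conventions.

```python
import numpy as np


def mask_2d_circular_sketch(shape_native, pixel_scales, radius, centre=(0.0, 0.0)):
    """Hypothetical sketch of a circular real-space mask (True = masked out)."""
    ny, nx = shape_native
    # Pixel-centre coordinates measured from the middle of the native grid:
    # y decreases with row index, x increases with column index.
    y = ((ny - 1) / 2 - np.arange(ny)) * pixel_scales[0]
    x = (np.arange(nx) - (nx - 1) / 2) * pixel_scales[1]
    yy, xx = np.meshgrid(y, x, indexing="ij")
    distance = np.hypot(yy - centre[0], xx - centre[1])
    # Pixels whose centres fall outside the radius are masked.
    return distance > radius


mask = mask_2d_circular_sketch((5, 5), (1.0, 1.0), radius=1.0)
shifted = mask_2d_circular_sketch((5, 5), (1.0, 1.0), radius=0.5, centre=(1.0, 0.0))
```

On a 5×5 grid with unit pixels and radius 1, only the central pixel and its four edge-adjacent neighbours stay unmasked; shifting the centre up by one pixel moves the unmasked region accordingly.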
- class autojax.tests.DataGenerated(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3, K_: int = 1024, P_: int = 32, S_: int = 256)
Bases: Data
Generate data for testing.
- property K: int
- K_: int = 1024
- property P: int
- P_: int = 32
- property S: int
- S_: int = 256
- property data: ndarray[tuple[int], complex128]
Generate a random data map.
- property dirty_image: ndarray[tuple[int], float64]
Generate a random dirty image.
- property neighbors: ndarray[tuple[int, int], int64]
Generate random neighbors.
- property neighbors_sizes: ndarray[tuple[int], int64]
- property noise_map: ndarray[tuple[int], complex128]
Generate a random noise map.
- property pix_indexes_for_sub_slim_index: ndarray[tuple[int, int], int64]
- property pix_weights_for_sub_slim_index: ndarray[tuple[int, int], float64]
- property uv_wavelengths: ndarray[tuple[int, int], float64]
Generate random uv wavelengths.
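DataGenerated replaces the visibility-space inputs with random arrays. A minimal sketch of such generation with a seeded numpy.random.Generator follows, assuming K is the number of visibilities (K_ = 1024 by default) and that uv_wavelengths holds one (u, v) pair per visibility. The function name and the value ranges are hypothetical choices, not the ones DataGenerated uses.

```python
import numpy as np


def generate_visibilities_sketch(K: int, rng: np.random.Generator):
    """Hypothetical sketch: random complex data, complex noise and uv coordinates."""
    data = rng.standard_normal(K) + 1j * rng.standard_normal(K)
    # Keep both noise components strictly positive so 1 / noise**2 stays finite.
    noise_map = rng.uniform(0.5, 1.5, K) + 1j * rng.uniform(0.5, 1.5, K)
    # Baselines in wavelengths; the +-1e5 range is an arbitrary choice.
    uv_wavelengths = rng.uniform(-1e5, 1e5, size=(K, 2))
    return data, noise_map, uv_wavelengths


data, noise_map, uv_wavelengths = generate_visibilities_sketch(1024, np.random.default_rng(0))
```

Passing an explicitly seeded Generator (for example one seeded by deterministic_seed) makes every benchmark run see identical inputs.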
- class autojax.tests.DataLoaded(pixel_scale: float = 0.2, _centre: float = 0.0, N_: int = 30, coefficient: float = 1.0, B: int = 3, path: Path = PosixPath('/home/runner/work/python-autojax/python-autojax/dataset/data.npz'))
Bases: Data
Load data from file.
- property K: int
- property M: int
- property P: int
- property S: int
- property data
- property dirty_image
- property neighbors
- property neighbors_sizes
- property noise_map
- path: Path = PosixPath('/home/runner/work/python-autojax/python-autojax/dataset/data.npz')
- property pix_indexes_for_sub_slim_index
- property pix_weights_for_sub_slim_index
- property uv_wavelengths
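DataLoaded reads its arrays from an .npz archive at path. A minimal round-trip sketch with np.savez and np.load is given below; the key names in KEYS and the helper functions are hypothetical, and the real dataset/data.npz may be laid out differently.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical key names; the real dataset/data.npz may use different ones.
KEYS = ("data", "noise_map", "uv_wavelengths")


def save_dataset(path: Path, **arrays: np.ndarray) -> None:
    np.savez(path, **arrays)


def load_dataset(path: Path) -> dict:
    # np.load returns a lazy NpzFile for .npz archives; indexing reads an array
    # into memory, so the dict stays valid after the file is closed.
    with np.load(path) as npz:
        return {key: npz[key] for key in KEYS}


rng = np.random.default_rng(0)
arrays = {
    "data": rng.standard_normal(8) + 1j * rng.standard_normal(8),
    "noise_map": rng.uniform(0.5, 1.5, 8) + 1j * rng.uniform(0.5, 1.5, 8),
    "uv_wavelengths": rng.uniform(-1e5, 1e5, size=(8, 2)),
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "data.npz"
    save_dataset(path, **arrays)
    loaded = load_dataset(path)
```

Complex dtypes survive the round trip, so the loaded data and noise_map keep their complex128 type.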
- class autojax.tests.Reference(data: Data)
Bases: object
Generate reference values for testing.
- property ref: dict
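Reference wraps a Data instance and exposes the reference values as a dict. A minimal sketch of that pattern follows, assuming the references are computed lazily by trusted functions and that other implementations are then checked with np.testing.assert_allclose. ReferenceSketch, check and the "data_vector" entry (written here as mapping_matrix.T @ dirty_image) are illustrative only, not autojax's definitions.

```python
import numpy as np


class ReferenceSketch:
    """Hypothetical sketch: reference outputs computed once from trusted functions."""

    def __init__(self, data: dict, reference_funcs: dict):
        self.data = data
        self.reference_funcs = reference_funcs
        self._ref = None

    @property
    def ref(self) -> dict:
        # Computed lazily on first access, then cached for every later test.
        if self._ref is None:
            self._ref = {name: func(self.data) for name, func in self.reference_funcs.items()}
        return self._ref


def check(ref: dict, name: str, result, rtol: float = 1e-10) -> None:
    np.testing.assert_allclose(result, ref[name], rtol=rtol)


rng = np.random.default_rng(1)
data = {"mapping_matrix": rng.random((6, 4)), "dirty_image": rng.random(6)}
# Illustrative reference: a matrix-vector product computed the "obvious" way.
reference = ReferenceSketch(data, {"data_vector": lambda d: d["mapping_matrix"].T @ d["dirty_image"]})
# A candidate implementation is then checked against the cached value.
candidate = np.einsum("ij,i->j", data["mapping_matrix"], data["dirty_image"])
check(reference.ref, "data_vector", candidate)
```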
- class autojax.tests.TestCurvatureMatrix
Bases: object
Compute the curvature matrix via various methods.
- The input w can be w_tilde, its preload form, or its compact form. The fully expanded w_tilde is allowed because it can be expanded in memory once, outside the MCMC loop.
- The input mapping_matrix must be given in its sparse form, such as pix_weights_for_sub_slim_index, …
This is because the mapping matrix has to be generated on the fly anyway, so even the dense form must be generated from the sparse form at some point in the MCMC loop.
- test_curvature_matrix_jax(data_bundle, benchmark)
From w_tilde, construct a dense mapping matrix.
- test_curvature_matrix_jax_BCOO(data_bundle, benchmark)
From w_tilde, construct a BCOO mapping matrix.
- test_curvature_matrix_jax_compact_sparse(data_bundle, benchmark)
From w_compact, with an internal sparse mapping matrix.
- test_curvature_matrix_jax_compact_sparse_BCOO(data_bundle, benchmark)
From w_compact, with a BCOO mapping matrix on the left and an internal sparse mapping matrix on the right.
- test_curvature_matrix_jax_sparse(data_bundle, benchmark)
From w_tilde, with an internal sparse mapping matrix.
- test_curvature_matrix_numba(data_bundle, benchmark)
From w_tilde, construct a dense mapping matrix.
- test_curvature_matrix_numba_compact_sparse(data_bundle, benchmark)
From w_compact, with an internal sparse mapping matrix and a sparse matmul.
- test_curvature_matrix_numba_compact_sparse_direct(data_bundle, benchmark)
From w_compact, with an internal sparse mapping matrix and a direct 4-loop matmul.
- test_curvature_matrix_numba_sparse(data_bundle, benchmark)
From w_tilde, with an internal sparse mapping matrix.
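The dense and sparse variants above compute the same curvature matrix F = Mᵀ W M. A minimal sketch of the equivalence follows, assuming one sub-pixel per image pixel (so slim_index_for_sub_slim_index and sub_fraction are ignored), that row i of the sparse form lists pix_size_for_sub_slim_index[i] valid (index, weight) pairs, and that a random symmetric matrix stands in for w_tilde. The function names are hypothetical; the tests' actual internal sparse formats may differ.

```python
import numpy as np


def mapping_matrix_dense(pix_indexes, pix_weights, pix_sizes, n_source):
    """Expand the sparse (index, weight) form into a dense (n_image, n_source) matrix."""
    n_image = pix_indexes.shape[0]
    mapping = np.zeros((n_image, n_source))
    for i in range(n_image):
        for a in range(pix_sizes[i]):
            mapping[i, pix_indexes[i, a]] += pix_weights[i, a]
    return mapping


def curvature_dense(w_tilde, mapping):
    # F = M^T W M with a materialised mapping matrix.
    return mapping.T @ w_tilde @ mapping


def curvature_sparse_direct(w_tilde, pix_indexes, pix_weights, pix_sizes, n_source):
    """Same F, accumulated straight from the sparse form with four nested loops."""
    n_image = w_tilde.shape[0]
    F = np.zeros((n_source, n_source))
    for i in range(n_image):
        for j in range(n_image):
            for a in range(pix_sizes[i]):
                left = pix_weights[i, a] * w_tilde[i, j]
                for b in range(pix_sizes[j]):
                    F[pix_indexes[i, a], pix_indexes[j, b]] += left * pix_weights[j, b]
    return F


rng = np.random.default_rng(2)
n_image, n_source, P = 6, 4, 3
pix_sizes = rng.integers(1, P + 1, size=n_image)
pix_indexes = rng.integers(0, n_source, size=(n_image, P))
pix_weights = rng.random((n_image, P))
A = rng.standard_normal((n_image, n_image))
w_tilde = A @ A.T  # symmetric stand-in for w_tilde

F_dense = curvature_dense(w_tilde, mapping_matrix_dense(pix_indexes, pix_weights, pix_sizes, n_source))
F_sparse = curvature_sparse_direct(w_tilde, pix_indexes, pix_weights, pix_sizes, n_source)
```

The direct loop never materialises M, which is the point of requiring the sparse form: it is what exists inside the MCMC loop in any case.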
- class autojax.tests.TestJax
Bases: object
- mod = <module 'autojax.jax' from '/home/runner/work/python-autojax/python-autojax/src/autojax/jax.py'>
- test_constant_regularization_matrix_from_jax(data_bundle, benchmark)
- test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_jax(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_jax(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_from_jax(data_bundle, benchmark)
- test_data_vector_from_jax(data_bundle, benchmark)
- test_log_likelihood_function_via_w_compact_from_jax(data_bundle, benchmark)
- test_log_likelihood_function_via_w_tilde_from_jax(data_bundle, benchmark)
- test_mapping_matrix_from_jax(data_bundle, benchmark)
- test_mask_2d_circular_from_jax(data_bundle, benchmark)
- test_noise_normalization_complex_from_jax(data_bundle, benchmark)
- test_reconstruction_positive_negative_from_jax(data_bundle, benchmark)
- test_sparse_mapping_matrix_transpose_matmul_jax(data_bundle, benchmark)
- test_w_compact_curvature_interferometer_from_jax(data_bundle, benchmark)
- test_w_tilde_curvature_interferometer_from_jax(data_bundle, benchmark)
- test_w_tilde_curvature_preload_interferometer_from_jax(data_bundle, benchmark)
- test_w_tilde_data_interferometer_from_jax(data_bundle, benchmark)
- test_w_tilde_via_compact_from_jax(data_bundle, benchmark)
- test_w_tilde_via_preload_from_jax(data_bundle, benchmark)
- class autojax.tests.TestNumba
Bases: object
- mod = <module 'autojax.numba' from '/home/runner/work/python-autojax/python-autojax/src/autojax/numba.py'>
- test_constant_regularization_matrix_from_numba(data_bundle, benchmark)
- test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_numba(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_numba(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_from_numba(data_bundle, benchmark)
- test_data_vector_from_numba(data_bundle, benchmark)
- test_log_likelihood_function_via_w_compact_from_numba(data_bundle, benchmark)
- test_log_likelihood_function_via_w_tilde_from_numba(data_bundle, benchmark)
- test_mapping_matrix_from_numba(data_bundle, benchmark)
- test_mask_2d_circular_from_numba(data_bundle, benchmark)
- test_noise_normalization_complex_from_numba(data_bundle, benchmark)
- test_reconstruction_positive_negative_from_numba(data_bundle, benchmark)
- test_sparse_mapping_matrix_transpose_matmul_numba(data_bundle, benchmark)
- test_w_compact_curvature_interferometer_from_numba(data_bundle, benchmark)
- test_w_tilde_curvature_interferometer_from_numba(data_bundle, benchmark)
- test_w_tilde_curvature_preload_interferometer_from_numba(data_bundle, benchmark)
- test_w_tilde_data_interferometer_from_numba(data_bundle, benchmark)
- test_w_tilde_via_compact_from_numba(data_bundle, benchmark)
- test_w_tilde_via_preload_from_numba(data_bundle, benchmark)
- class autojax.tests.TestOriginal
Bases: object
- mod = <module 'autojax.original' from '/home/runner/work/python-autojax/python-autojax/src/autojax/original.py'>
- test_constant_regularization_matrix_from_original(data_bundle, benchmark)
- test_curvature_matrix_via_w_compact_sparse_mapping_matrix_from_original(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_original(data_bundle, benchmark)
- test_curvature_matrix_via_w_tilde_from_original(data_bundle, benchmark)
- test_data_vector_from_original(data_bundle, benchmark)
- test_log_likelihood_function_via_w_compact_from_original(data_bundle, benchmark)
- test_log_likelihood_function_via_w_tilde_from_original(data_bundle, benchmark)
- test_mapping_matrix_from_original(data_bundle, benchmark)
- test_mask_2d_circular_from_original(data_bundle, benchmark)
- test_noise_normalization_complex_from_original(data_bundle, benchmark)
- test_reconstruction_positive_negative_from_original(data_bundle, benchmark)
- test_sparse_mapping_matrix_transpose_matmul_original(data_bundle, benchmark)
- test_w_compact_curvature_interferometer_from_original(data_bundle, benchmark)
- test_w_tilde_curvature_interferometer_from_original(data_bundle, benchmark)
- test_w_tilde_curvature_preload_interferometer_from_original(data_bundle, benchmark)
- test_w_tilde_data_interferometer_from_original(data_bundle, benchmark)
- test_w_tilde_via_compact_from_original(data_bundle, benchmark)
- test_w_tilde_via_preload_from_original(data_bundle, benchmark)
- class autojax.tests.TestWTilde
Bases: object
Compute w_tilde via various methods.
This adds to the existing benchmarks to compare the performance of the preload method.
The test names are a bit unusual, but they are designed to be filtered like this:
pytest -k w_tilde_curvature_interferometer_from
This compares:
(1) direct computation of w_tilde;
(2) (prefixed by _compact/_preload) computing w_tilde in the preload/compact form;
(3) (prefixed by _expanded) computing w_tilde as in (2) and then expanding it fully.
(1) and (3) should be compared if w_tilde is needed in the full form. (1) and (2) should be compared if w_tilde is needed regardless of the form.
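The compact/preload forms pay off because, on a regular pixel grid, w_tilde depends only on the pixel offset between two pixels. A minimal sketch of variants (1), (2) and (3) follows, under an assumed definition w_tilde[i, j] = Σₖ cos(2π(Δy·u₀ₖ + Δx·u₁ₖ)) / σₖ², where σ is the real noise per visibility. The pairing of grid axes with uv columns, the grid orientation and the offset-table layout are assumptions; autojax's actual w_compact and w_tilde_preload layouts may differ, for example by exploiting the symmetry of cos to halve the table.

```python
import numpy as np


def w_tilde_direct(noise_real, uv, grid):
    """(1) Direct O(N^2 K) sum; grid[:, 0] is y and grid[:, 1] is x."""
    dy = grid[None, :, 0] - grid[:, None, 0]  # dy[i, j] = y_j - y_i
    dx = grid[None, :, 1] - grid[:, None, 1]
    phase = 2 * np.pi * (dy[..., None] * uv[:, 0] + dx[..., None] * uv[:, 1])
    return np.sum(np.cos(phase) / noise_real**2, axis=-1)


def w_compact_sketch(noise_real, uv, shape_native, pixel_scale):
    """(2) The same sum, evaluated once per (row, column) pixel offset."""
    ny, nx = shape_native
    d_row = np.arange(-(ny - 1), ny)
    d_col = np.arange(-(nx - 1), nx)
    # Assumed geometry: y decreases with row index, x increases with column.
    off_y = (-d_row * pixel_scale)[:, None, None]
    off_x = (d_col * pixel_scale)[None, :, None]
    phase = 2 * np.pi * (off_y * uv[:, 0] + off_x * uv[:, 1])
    return np.sum(np.cos(phase) / noise_real**2, axis=-1)  # (2ny-1, 2nx-1)


def w_tilde_expand(w_compact, native_index, shape_native):
    """(3) Expand the offset table to the full (N, N) matrix by lookup."""
    ny, nx = shape_native
    row, col = native_index[:, 0], native_index[:, 1]
    return w_compact[row[None, :] - row[:, None] + ny - 1,
                     col[None, :] - col[:, None] + nx - 1]


rng = np.random.default_rng(3)
shape_native, pixel_scale = (4, 5), 0.1
keep = np.ones(shape_native, dtype=bool)
keep[0, 0] = keep[3, 4] = False  # a small "mask"
native_index = np.argwhere(keep)  # (N, 2) row/column of each unmasked pixel
grid = np.stack([((shape_native[0] - 1) / 2 - native_index[:, 0]) * pixel_scale,
                 (native_index[:, 1] - (shape_native[1] - 1) / 2) * pixel_scale], axis=1)
K = 7
uv = rng.uniform(-5.0, 5.0, size=(K, 2))
noise_real = rng.uniform(0.5, 2.0, size=K)

w_full = w_tilde_direct(noise_real, uv, grid)
w_compact = w_compact_sketch(noise_real, uv, shape_native, pixel_scale)
w_expanded = w_tilde_expand(w_compact, native_index, shape_native)
```

The direct form costs O(N²K), whereas the compact form evaluates only (2ny − 1)(2nx − 1) ≈ 4N offsets per visibility when most pixels are unmasked; expanding it afterwards is pure indexing, which is why (1) versus (3) measures the cost of obtaining the full matrix.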
- autojax.tests.deterministic_seed(string: str, *numbers: int) → int
Generate a deterministic seed from the class name.
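A seed derived from the class name keeps each generated test class on its own reproducible random stream. Python's built-in hash() of a str is salted per process, so it cannot serve this purpose; a cryptographic digest can. The sketch below is a hypothetical implementation using hashlib, not necessarily how deterministic_seed works internally.

```python
import hashlib

import numpy as np


def deterministic_seed_sketch(string: str, *numbers: int) -> int:
    """Hypothetical sketch: a seed that is stable across processes and machines."""
    # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so a
    # cryptographic digest of an unambiguous repr is used instead. Plain Python
    # ints are assumed; numpy integers have a different repr.
    payload = repr((string, *numbers)).encode("utf-8")
    digest = hashlib.sha256(payload).digest()
    # Keep 64 bits: a non-negative int accepted by np.random.default_rng.
    return int.from_bytes(digest[:8], "little")


seed = deterministic_seed_sketch("TestJax", 30, 3)
rng = np.random.default_rng(seed)
```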
- autojax.tests.gen_neighbors(S, P, rng) → ndarray[tuple[int, int], int64]
Generate random neighbors.
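A minimal sketch of random neighbor generation, assuming the result has shape (S, P), with row s listing up to P neighbor indices of source pixel s and unused slots set to -1, and assuming neighbor relations should be symmetric as they are on a mesh. The padding value, the symmetry requirement and the function name are assumptions, not confirmed behaviour of gen_neighbors.

```python
import numpy as np


def gen_neighbors_sketch(S: int, P: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sketch: random symmetric neighbor lists of shape (S, P), padded with -1."""
    neighbors = np.full((S, P), -1, dtype=np.int64)
    sizes = np.zeros(S, dtype=np.int64)
    # Propose random pairs; accept one only if both rows still have room,
    # so that "b is a neighbor of a" always implies "a is a neighbor of b".
    for _ in range(S * P):
        a, b = rng.integers(0, S, size=2)
        if a == b or sizes[a] == P or sizes[b] == P or b in neighbors[a, : sizes[a]]:
            continue
        neighbors[a, sizes[a]] = b
        neighbors[b, sizes[b]] = a
        sizes[a] += 1
        sizes[b] += 1
    return neighbors


neighbors = gen_neighbors_sketch(S=50, P=6, rng=np.random.default_rng(4))
# Valid entries are packed to the left, so counting them gives the row sizes.
neighbors_sizes = (neighbors >= 0).sum(axis=1)
```

Symmetric neighbor lists give a symmetric regularization matrix when each pixel is penalised against each of its listed neighbors.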