Skip to content

Commit

Permalink
Merge pull request #269 from Jammy2211/feature/sky_simplify
Browse files Browse the repository at this point in the history
Feature/sky simplify
  • Loading branch information
Jammy2211 committed May 14, 2024
2 parents abec2ed + 2a48038 commit 80c7145
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 77 deletions.
1 change: 1 addition & 0 deletions autolens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autoarray.dataset.interferometer.dataset import (
Interferometer,
)
from autoarray.dataset.dataset_model import DatasetModel
from autoarray.mask.mask_1d import Mask1D
from autoarray.mask.mask_2d import Mask2D
from autoarray.operators.convolver import Convolver
Expand Down
8 changes: 6 additions & 2 deletions autolens/aggregator/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autolens.analysis.preloads import Preloads

from autogalaxy.aggregator.imaging import _imaging_from
from autogalaxy.aggregator.dataset_model import _dataset_model_from
from autogalaxy.aggregator import agg_util

from autolens.aggregator.tracer import _tracer_from
Expand Down Expand Up @@ -58,6 +59,8 @@ def _fit_imaging_from(

tracer_list = _tracer_from(fit=fit, instance=instance)

dataset_model_list = _dataset_model_from(fit=fit, instance=instance)

adapt_images_list = agg_util.adapt_images_from(fit=fit)

settings_inversion = settings_inversion or fit.value(name="settings_inversion")
Expand All @@ -68,8 +71,8 @@ def _fit_imaging_from(

fit_dataset_list = []

for dataset, tracer, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, adapt_images_list, mesh_grids_of_planes_list
for dataset, tracer, dataset_model, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, dataset_model_list, adapt_images_list, mesh_grids_of_planes_list
):
preloads = agg_util.preloads_from(
preloads_cls=Preloads,
Expand All @@ -82,6 +85,7 @@ def _fit_imaging_from(
FitImaging(
dataset=dataset,
tracer=tracer,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
7 changes: 5 additions & 2 deletions autolens/aggregator/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import autoarray as aa

from autogalaxy.aggregator.interferometer import _interferometer_from
from autogalaxy.aggregator.dataset_model import _dataset_model_from

from autolens.interferometer.fit_interferometer import FitInterferometer
from autolens.analysis.preloads import Preloads
Expand Down Expand Up @@ -60,6 +61,7 @@ def _fit_interferometer_from(
real_space_mask=real_space_mask,
)
tracer_list = _tracer_from(fit=fit, instance=instance)
dataset_model_list = _dataset_model_from(fit=fit, instance=instance)

adapt_images_list = agg_util.adapt_images_from(fit=fit)

Expand All @@ -71,8 +73,8 @@ def _fit_interferometer_from(

fit_dataset_list = []

for dataset, tracer, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, adapt_images_list, mesh_grids_of_planes_list
for dataset, tracer, dataset_model, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, dataset_model_list, adapt_images_list, mesh_grids_of_planes_list
):
preloads = agg_util.preloads_from(
preloads_cls=Preloads,
Expand All @@ -85,6 +87,7 @@ def _fit_interferometer_from(
FitInterferometer(
dataset=dataset,
tracer=tracer,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
25 changes: 9 additions & 16 deletions autolens/imaging/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self,
dataset: aa.Imaging,
tracer: Tracer,
sky: Optional[ag.LightProfile] = None,
dataset_model : Optional[aa.DatasetModel] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads: Preloads = Preloads(),
Expand Down Expand Up @@ -57,8 +57,8 @@ def __init__(
The imaging dataset which is fitted by the galaxies in the tracer.
tracer
The tracer of galaxies whose light profile images are used to fit the imaging data.
sky
Model component used to represent the background sky emission in an image (e.g. a `Sky` light profile).
dataset_model
Attributes which allow for parts of a dataset to be treated as a model (e.g. the background sky level).
adapt_images
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
reconstructed galaxy's morphology.
Expand All @@ -72,12 +72,11 @@ def __init__(
decorator take to run.
"""

super().__init__(dataset=dataset, run_time_dict=run_time_dict)
super().__init__(dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict)
AbstractFitInversion.__init__(
self=self, model_obj=tracer, sky=sky, settings_inversion=settings_inversion
self=self, model_obj=tracer, settings_inversion=settings_inversion
)

self.sky = sky
self.tracer = tracer

self.adapt_images = adapt_images
Expand All @@ -96,16 +95,11 @@ def blurred_image(self) -> aa.Array2D:

if self.preloads.blurred_image is None:

if isinstance(self.sky, ag.lp.Sky):
image = self.sky.image_2d_from(grid=self.dataset.grid)
else:
image = np.zeros(self.dataset.shape_slim)

return self.tracer.blurred_image_2d_from(
grid=self.dataset.grid,
convolver=self.dataset.convolver,
blurring_grid=self.dataset.blurring_grid,
) + image
)

return self.preloads.blurred_image

Expand All @@ -115,7 +109,7 @@ def profile_subtracted_image(self) -> aa.Array2D:
Returns the dataset's image with all blurred light profile images in the fit's tracer subtracted.
"""

return self.image - self.blurred_image
return self.data - self.blurred_image

@property
def tracer_to_inversion(self) -> TracerToInversion:
Expand All @@ -135,7 +129,6 @@ def tracer_to_inversion(self) -> TracerToInversion:
return TracerToInversion(
dataset=dataset,
tracer=self.tracer,
sky=self.sky,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
preloads=self.preloads,
Expand Down Expand Up @@ -302,7 +295,7 @@ def subtracted_images_of_planes_list(self) -> List[aa.Array2D]:
if i != galaxy_index
]

subtracted_image = self.image - sum(other_planes_model_images)
subtracted_image = self.data - sum(other_planes_model_images)

subtracted_images_of_planes_list.append(subtracted_image)

Expand Down Expand Up @@ -384,7 +377,7 @@ def refit_with_new_preloads(
return FitImaging(
dataset=self.dataset,
tracer=self.tracer,
sky=self.sky,
dataset_model=self.dataset_model,
adapt_images=self.adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
4 changes: 2 additions & 2 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def fit_from(
instance=instance, run_time_dict=run_time_dict
)

sky = self.sky_via_instance_from(instance=instance)
dataset_model = self.dataset_model_via_instance_from(instance=instance)

adapt_images = self.adapt_images_via_instance_from(instance=instance)

Expand All @@ -168,7 +168,7 @@ def fit_from(
return FitImaging(
dataset=self.dataset,
tracer=tracer,
sky=sky,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=self.settings_inversion,
preloads=preloads,
Expand Down
14 changes: 9 additions & 5 deletions autolens/interferometer/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
self,
dataset: aa.Interferometer,
tracer: Tracer,
dataset_model: Optional[aa.DatasetModel] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads: Preloads = Preloads(),
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
The interforometer dataset which is fitted by the galaxies in the tracer.
tracer
The tracer of galaxies whose light profile images are used to fit the interferometer data.
dataset_model
Attributes which allow for parts of a dataset to be treated as a model (e.g. the background sky level).
adapt_images
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
reconstructed galaxy's morphology.
Expand Down Expand Up @@ -82,9 +85,9 @@ def __init__(

self.run_time_dict = run_time_dict

super().__init__(dataset=dataset, run_time_dict=run_time_dict)
super().__init__(dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict)
AbstractFitInversion.__init__(
self=self, model_obj=tracer, sky=None, settings_inversion=settings_inversion
self=self, model_obj=tracer, settings_inversion=settings_inversion
)

@property
Expand All @@ -103,7 +106,7 @@ def profile_subtracted_visibilities(self) -> aa.Visibilities:
Returns the interferometer dataset's visibilities with all transformed light profile images in the fit's
tracer subtracted.
"""
return self.visibilities - self.profile_visibilities
return self.data - self.profile_visibilities

@property
def tracer_to_inversion(self) -> TracerToInversion:
Expand Down Expand Up @@ -219,7 +222,7 @@ def model_visibilities_of_planes_list(self) -> List[aa.Visibilities]:
galaxy_model_visibilities_dict = self.galaxy_model_visibilities_dict

model_visibilities_of_planes_list = [
aa.Visibilities.zeros(shape_slim=(self.dataset.visibilities.shape_slim,))
aa.Visibilities.zeros(shape_slim=(self.dataset.data.shape_slim,))
for i in range(self.tracer.total_planes)
]

Expand Down Expand Up @@ -274,8 +277,9 @@ def refit_with_new_preloads(
settings_inversion = self.settings_inversion

return FitInterferometer(
dataset=self.interferometer,
dataset=self.dataset,
tracer=self.tracer,
dataset_model=self.dataset_model,
adapt_images=self.adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
2 changes: 1 addition & 1 deletion autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def profile_log_likelihood_function(
instance=instance,
)

info_dict["number_of_visibilities"] = self.dataset.visibilities.shape[0]
info_dict["number_of_visibilities"] = self.dataset.data.shape[0]
info_dict["transformer_cls"] = self.dataset.transformer.__class__.__name__

self.output_profiling_info(
Expand Down
5 changes: 0 additions & 5 deletions autolens/lens/to_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(
self,
dataset: Optional[Union[aa.Imaging, aa.Interferometer, aa.DatasetInterface]],
tracer,
sky: Optional[Basis] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads=Preloads(),
Expand All @@ -26,7 +25,6 @@ def __init__(

super().__init__(
dataset=dataset,
sky=sky,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down Expand Up @@ -106,7 +104,6 @@ def lp_linear_func_list_galaxy_dict(
galaxies_to_inversion = ag.GalaxiesToInversion(
dataset=dataset,
galaxies=galaxies,
sky=self.sky,
settings_inversion=self.settings_inversion,
adapt_images=self.adapt_images,
run_time_dict=self.run_time_dict,
Expand Down Expand Up @@ -168,7 +165,6 @@ def image_plane_mesh_grid_pg_list(self) -> List[List]:
to_inversion = ag.GalaxiesToInversion(
dataset=self.dataset,
galaxies=galaxies,
sky=self.sky,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
run_time_dict=self.run_time_dict,
Expand Down Expand Up @@ -245,7 +241,6 @@ def mapper_galaxy_dict(self) -> Dict[aa.AbstractMapper, ag.Galaxy]:
to_inversion = ag.GalaxiesToInversion(
dataset=self.dataset,
galaxies=galaxies,
sky=self.sky,
preloads=self.preloads,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
Expand Down
3 changes: 0 additions & 3 deletions autolens/point/fit_point/positions_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from functools import partial
from typing import Optional

import autoarray as aa
import autogalaxy as ag

from autolens.point.point_dataset import PointDict
from autolens.point.point_dataset import PointDataset
from autolens.point.point_solver import PointSolver
from autolens.lens.tracer import Tracer

Expand Down
5 changes: 5 additions & 0 deletions test_autolens/aggregator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def aggregator_from(database_file, analysis, model, samples):

@pytest.fixture(name="model")
def make_model():

dataset_model = af.Model(al.DatasetModel)
dataset_model.background_sky_level = af.UniformPrior(lower_limit=0.5, upper_limit=1.5)

return af.Collection(
dataset_model=dataset_model,
galaxies=af.Collection(
lens=af.Model(al.Galaxy, redshift=0.5, light=al.lp.Sersic),
source=af.Model(al.Galaxy, redshift=1.0, light=al.lp.Sersic),
Expand Down
2 changes: 2 additions & 0 deletions test_autolens/aggregator/test_aggregator_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test__fit_imaging_randomly_drawn_via_pdf_gen_from(
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].tracer.galaxies[1].redshift == 1.0

assert fit_list[0].dataset_model.background_sky_level == 10.0

assert i == 2

clean(database_file=database_file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from(
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].tracer.galaxies[1].redshift == 1.0

assert fit_list[0].dataset_model.background_sky_level == 10.0

assert i == 2

clean(database_file=database_file)
Expand Down

0 comments on commit 80c7145

Please sign in to comment.