Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sky simplify #269

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions autolens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autoarray.dataset.interferometer.dataset import (
Interferometer,
)
from autoarray.dataset.dataset_model import DatasetModel
from autoarray.mask.mask_1d import Mask1D
from autoarray.mask.mask_2d import Mask2D
from autoarray.operators.convolver import Convolver
Expand Down
8 changes: 6 additions & 2 deletions autolens/aggregator/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autolens.analysis.preloads import Preloads

from autogalaxy.aggregator.imaging import _imaging_from
from autogalaxy.aggregator.dataset_model import _dataset_model_from
from autogalaxy.aggregator import agg_util

from autolens.aggregator.tracer import _tracer_from
Expand Down Expand Up @@ -58,6 +59,8 @@ def _fit_imaging_from(

tracer_list = _tracer_from(fit=fit, instance=instance)

dataset_model_list = _dataset_model_from(fit=fit, instance=instance)

adapt_images_list = agg_util.adapt_images_from(fit=fit)

settings_inversion = settings_inversion or fit.value(name="settings_inversion")
Expand All @@ -68,8 +71,8 @@ def _fit_imaging_from(

fit_dataset_list = []

for dataset, tracer, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, adapt_images_list, mesh_grids_of_planes_list
for dataset, tracer, dataset_model, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, dataset_model_list, adapt_images_list, mesh_grids_of_planes_list
):
preloads = agg_util.preloads_from(
preloads_cls=Preloads,
Expand All @@ -82,6 +85,7 @@ def _fit_imaging_from(
FitImaging(
dataset=dataset,
tracer=tracer,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
7 changes: 5 additions & 2 deletions autolens/aggregator/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import autoarray as aa

from autogalaxy.aggregator.interferometer import _interferometer_from
from autogalaxy.aggregator.dataset_model import _dataset_model_from

from autolens.interferometer.fit_interferometer import FitInterferometer
from autolens.analysis.preloads import Preloads
Expand Down Expand Up @@ -60,6 +61,7 @@ def _fit_interferometer_from(
real_space_mask=real_space_mask,
)
tracer_list = _tracer_from(fit=fit, instance=instance)
dataset_model_list = _dataset_model_from(fit=fit, instance=instance)

adapt_images_list = agg_util.adapt_images_from(fit=fit)

Expand All @@ -71,8 +73,8 @@ def _fit_interferometer_from(

fit_dataset_list = []

for dataset, tracer, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, adapt_images_list, mesh_grids_of_planes_list
for dataset, tracer, dataset_model, adapt_images, mesh_grids_of_planes in zip(
dataset_list, tracer_list, dataset_model_list, adapt_images_list, mesh_grids_of_planes_list
):
preloads = agg_util.preloads_from(
preloads_cls=Preloads,
Expand All @@ -85,6 +87,7 @@ def _fit_interferometer_from(
FitInterferometer(
dataset=dataset,
tracer=tracer,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
25 changes: 9 additions & 16 deletions autolens/imaging/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self,
dataset: aa.Imaging,
tracer: Tracer,
sky: Optional[ag.LightProfile] = None,
dataset_model : Optional[aa.DatasetModel] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads: Preloads = Preloads(),
Expand Down Expand Up @@ -57,8 +57,8 @@ def __init__(
The imaging dataset which is fitted by the galaxies in the tracer.
tracer
The tracer of galaxies whose light profile images are used to fit the imaging data.
sky
Model component used to represent the background sky emission in an image (e.g. a `Sky` light profile).
dataset_model
Attributes which allow for parts of a dataset to be treated as a model (e.g. the background sky level).
adapt_images
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
reconstructed galaxy's morphology.
Expand All @@ -72,12 +72,11 @@ def __init__(
decorator take to run.
"""

super().__init__(dataset=dataset, run_time_dict=run_time_dict)
super().__init__(dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict)
AbstractFitInversion.__init__(
self=self, model_obj=tracer, sky=sky, settings_inversion=settings_inversion
self=self, model_obj=tracer, settings_inversion=settings_inversion
)

self.sky = sky
self.tracer = tracer

self.adapt_images = adapt_images
Expand All @@ -96,16 +95,11 @@ def blurred_image(self) -> aa.Array2D:

if self.preloads.blurred_image is None:

if isinstance(self.sky, ag.lp.Sky):
image = self.sky.image_2d_from(grid=self.dataset.grid)
else:
image = np.zeros(self.dataset.shape_slim)

return self.tracer.blurred_image_2d_from(
grid=self.dataset.grid,
convolver=self.dataset.convolver,
blurring_grid=self.dataset.blurring_grid,
) + image
)

return self.preloads.blurred_image

Expand All @@ -115,7 +109,7 @@ def profile_subtracted_image(self) -> aa.Array2D:
Returns the dataset's image with all blurred light profile images in the fit's tracer subtracted.
"""

return self.image - self.blurred_image
return self.data - self.blurred_image

@property
def tracer_to_inversion(self) -> TracerToInversion:
Expand All @@ -135,7 +129,6 @@ def tracer_to_inversion(self) -> TracerToInversion:
return TracerToInversion(
dataset=dataset,
tracer=self.tracer,
sky=self.sky,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
preloads=self.preloads,
Expand Down Expand Up @@ -302,7 +295,7 @@ def subtracted_images_of_planes_list(self) -> List[aa.Array2D]:
if i != galaxy_index
]

subtracted_image = self.image - sum(other_planes_model_images)
subtracted_image = self.data - sum(other_planes_model_images)

subtracted_images_of_planes_list.append(subtracted_image)

Expand Down Expand Up @@ -384,7 +377,7 @@ def refit_with_new_preloads(
return FitImaging(
dataset=self.dataset,
tracer=self.tracer,
sky=self.sky,
dataset_model=self.dataset_model,
adapt_images=self.adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
4 changes: 2 additions & 2 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def fit_from(
instance=instance, run_time_dict=run_time_dict
)

sky = self.sky_via_instance_from(instance=instance)
dataset_model = self.dataset_model_via_instance_from(instance=instance)

adapt_images = self.adapt_images_via_instance_from(instance=instance)

Expand All @@ -168,7 +168,7 @@ def fit_from(
return FitImaging(
dataset=self.dataset,
tracer=tracer,
sky=sky,
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=self.settings_inversion,
preloads=preloads,
Expand Down
14 changes: 9 additions & 5 deletions autolens/interferometer/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
self,
dataset: aa.Interferometer,
tracer: Tracer,
dataset_model: Optional[aa.DatasetModel] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads: Preloads = Preloads(),
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
The interforometer dataset which is fitted by the galaxies in the tracer.
tracer
The tracer of galaxies whose light profile images are used to fit the interferometer data.
dataset_model
Attributes which allow for parts of a dataset to be treated as a model (e.g. the background sky level).
adapt_images
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
reconstructed galaxy's morphology.
Expand Down Expand Up @@ -82,9 +85,9 @@ def __init__(

self.run_time_dict = run_time_dict

super().__init__(dataset=dataset, run_time_dict=run_time_dict)
super().__init__(dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict)
AbstractFitInversion.__init__(
self=self, model_obj=tracer, sky=None, settings_inversion=settings_inversion
self=self, model_obj=tracer, settings_inversion=settings_inversion
)

@property
Expand All @@ -103,7 +106,7 @@ def profile_subtracted_visibilities(self) -> aa.Visibilities:
Returns the interferometer dataset's visibilities with all transformed light profile images in the fit's
tracer subtracted.
"""
return self.visibilities - self.profile_visibilities
return self.data - self.profile_visibilities

@property
def tracer_to_inversion(self) -> TracerToInversion:
Expand Down Expand Up @@ -219,7 +222,7 @@ def model_visibilities_of_planes_list(self) -> List[aa.Visibilities]:
galaxy_model_visibilities_dict = self.galaxy_model_visibilities_dict

model_visibilities_of_planes_list = [
aa.Visibilities.zeros(shape_slim=(self.dataset.visibilities.shape_slim,))
aa.Visibilities.zeros(shape_slim=(self.dataset.data.shape_slim,))
for i in range(self.tracer.total_planes)
]

Expand Down Expand Up @@ -274,8 +277,9 @@ def refit_with_new_preloads(
settings_inversion = self.settings_inversion

return FitInterferometer(
dataset=self.interferometer,
dataset=self.dataset,
tracer=self.tracer,
dataset_model=self.dataset_model,
adapt_images=self.adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down
2 changes: 1 addition & 1 deletion autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def profile_log_likelihood_function(
instance=instance,
)

info_dict["number_of_visibilities"] = self.dataset.visibilities.shape[0]
info_dict["number_of_visibilities"] = self.dataset.data.shape[0]
info_dict["transformer_cls"] = self.dataset.transformer.__class__.__name__

self.output_profiling_info(
Expand Down
5 changes: 0 additions & 5 deletions autolens/lens/to_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(
self,
dataset: Optional[Union[aa.Imaging, aa.Interferometer, aa.DatasetInterface]],
tracer,
sky: Optional[Basis] = None,
adapt_images: Optional[ag.AdaptImages] = None,
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
preloads=Preloads(),
Expand All @@ -26,7 +25,6 @@ def __init__(

super().__init__(
dataset=dataset,
sky=sky,
adapt_images=adapt_images,
settings_inversion=settings_inversion,
preloads=preloads,
Expand Down Expand Up @@ -106,7 +104,6 @@ def lp_linear_func_list_galaxy_dict(
galaxies_to_inversion = ag.GalaxiesToInversion(
dataset=dataset,
galaxies=galaxies,
sky=self.sky,
settings_inversion=self.settings_inversion,
adapt_images=self.adapt_images,
run_time_dict=self.run_time_dict,
Expand Down Expand Up @@ -168,7 +165,6 @@ def image_plane_mesh_grid_pg_list(self) -> List[List]:
to_inversion = ag.GalaxiesToInversion(
dataset=self.dataset,
galaxies=galaxies,
sky=self.sky,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
run_time_dict=self.run_time_dict,
Expand Down Expand Up @@ -245,7 +241,6 @@ def mapper_galaxy_dict(self) -> Dict[aa.AbstractMapper, ag.Galaxy]:
to_inversion = ag.GalaxiesToInversion(
dataset=self.dataset,
galaxies=galaxies,
sky=self.sky,
preloads=self.preloads,
adapt_images=self.adapt_images,
settings_inversion=self.settings_inversion,
Expand Down
3 changes: 0 additions & 3 deletions autolens/point/fit_point/positions_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from functools import partial
from typing import Optional

import autoarray as aa
import autogalaxy as ag

from autolens.point.point_dataset import PointDict
from autolens.point.point_dataset import PointDataset
from autolens.point.point_solver import PointSolver
from autolens.lens.tracer import Tracer

Expand Down
5 changes: 5 additions & 0 deletions test_autolens/aggregator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def aggregator_from(database_file, analysis, model, samples):

@pytest.fixture(name="model")
def make_model():

dataset_model = af.Model(al.DatasetModel)
dataset_model.background_sky_level = af.UniformPrior(lower_limit=0.5, upper_limit=1.5)

return af.Collection(
dataset_model=dataset_model,
galaxies=af.Collection(
lens=af.Model(al.Galaxy, redshift=0.5, light=al.lp.Sersic),
source=af.Model(al.Galaxy, redshift=1.0, light=al.lp.Sersic),
Expand Down
2 changes: 2 additions & 0 deletions test_autolens/aggregator/test_aggregator_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test__fit_imaging_randomly_drawn_via_pdf_gen_from(
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].tracer.galaxies[1].redshift == 1.0

assert fit_list[0].dataset_model.background_sky_level == 10.0

assert i == 2

clean(database_file=database_file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from(
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].tracer.galaxies[1].redshift == 1.0

assert fit_list[0].dataset_model.background_sky_level == 10.0

assert i == 2

clean(database_file=database_file)
Expand Down