
API simplify DensityEstimator base class #1072

Open

wants to merge 16 commits into main

Conversation

@tomMoral (Contributor) commented Mar 21, 2024

Simplify the DensityEstimator base class and make it an abstract class.
Remove unnecessary functions.

A step in the direction of #1046
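
For illustration only (not the actual diff in this PR), here is a minimal sketch of what an abstract interface along these lines could look like. The method names log_prob, sample, and loss mirror the estimator API used elsewhere in sbi; everything else in this sketch is an assumption:

# Illustrative sketch only -- not the code in this PR.
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class DensityEstimator(ABC, nn.Module):
    # Minimal abstract interface; concrete backends (e.g. nflows- or
    # zuko-based flows) implement the three methods below.

    @abstractmethod
    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
        """Return log p(input | condition)."""

    @abstractmethod
    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        """Draw samples from p(. | condition)."""

    @abstractmethod
    def loss(self, input: Tensor, condition: Tensor) -> Tensor:
        """Per-sample training loss, typically the negative log_prob."""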

codecov bot commented Mar 21, 2024

Codecov Report

Attention: Patch coverage is 84.84848%, with 5 lines in your changes missing coverage. Please review.

Project coverage is 76.86%. Comparing base (0b5f931) to head (d47f4a7).

❗ Current head d47f4a7 differs from the pull request's most recent head c6ff223. Consider uploading reports for commit c6ff223 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1072      +/-   ##
==========================================
- Coverage   85.13%   76.86%   -8.27%     
==========================================
  Files          90       89       -1     
  Lines        6651     6558      -93     
==========================================
- Hits         5662     5041     -621     
- Misses        989     1517     +528     
Flag         Coverage Δ
unittests    76.86% <84.84%> (-8.27%) ⬇️

Flags with carried forward coverage won't be shown.

Files                                                 Coverage Δ
sbi/inference/snpe/snpe_a.py                          65.15% <100.00%> (-25.07%) ⬇️
sbi/neural_nets/density_estimators/nflows_flow.py     98.21% <100.00%> (+35.46%) ⬆️
sbi/neural_nets/density_estimators/zuko_flow.py       97.91% <100.00%> (+33.47%) ⬆️
sbi/neural_nets/density_estimators/base.py            65.00% <66.66%> (+9.44%) ⬆️
sbi/utils/user_input_checks.py                        85.36% <62.50%> (+1.60%) ⬆️

... and 51 files with indirect coverage changes

@tomMoral (Contributor, Author) commented:

Finally it's green ^^

@manuelgloeckler (Contributor) commented Mar 22, 2024

Great :). I just checked with @michaeldeistler, who also works on the density estimator functionality in #1066, so we will delay merging this until #1066 is done.

@tomMoral mentioned this pull request on Mar 23, 2024
@manuelgloeckler added the "blocked" label (Something is in the way of fixing this. Refer to it in the issue) on Apr 2, 2024
@janfb self-assigned this on Apr 4, 2024
@tomMoral (Contributor, Author) commented Apr 24, 2024

#1138 conflicts with this PR.
As this PR modifies the base class, merging it quickly would be a good idea; otherwise we should drop it, as it will constantly need conflicts resolved.

I am not sure how to solve the conflict here:

  • From our discussion during the sprint, I thought the goal was to avoid having constraints on the base class, so I went ahead and removed the attributes that were not used in it (net/_condition_shape).
  • #1138 (Give all DensityEstimators an input_shape) adds an attribute which is not used in the class.

I think if we have input/condition_shape attributes, it would make sense to add a mechanism to check these shapes in all the class's functions, something like:

def _check_shape(self, x=None, theta=None):
    if x is not None:
        check_input_shape(x, self._input_shape)
    if theta is not None:
        check_condition_shape(theta, self.condition_shape)

Maybe this could also handle some of the reshaping from #1066?
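
Purely for illustration (not part of the PR), a self-contained sketch of how such a _check_shape hook could be wired into the public methods; the class name, the trailing-shape comparison, and the error messages are assumptions for this example, not existing sbi code:

# Hypothetical sketch only -- not code from this PR or from sbi.
from typing import Optional, Tuple

from torch import Tensor


class ShapeCheckedEstimator:
    """Toy estimator illustrating a shared shape-checking hook (hypothetical)."""

    def __init__(self, input_shape: Tuple[int, ...], condition_shape: Tuple[int, ...]):
        self._input_shape = tuple(input_shape)
        self.condition_shape = tuple(condition_shape)

    def _check_shape(self, x: Optional[Tensor] = None, theta: Optional[Tensor] = None) -> None:
        # Compare the trailing dimensions of each tensor against the stored event shapes.
        if x is not None and tuple(x.shape[-len(self._input_shape):]) != self._input_shape:
            raise ValueError(f"Expected input event shape {self._input_shape}, got {tuple(x.shape)}.")
        if theta is not None and tuple(theta.shape[-len(self.condition_shape):]) != self.condition_shape:
            raise ValueError(f"Expected condition event shape {self.condition_shape}, got {tuple(theta.shape)}.")

    def log_prob(self, x: Tensor, theta: Tensor) -> Tensor:
        # Validate shapes once at the public entry point, then delegate to the
        # backend-specific implementation (omitted in this sketch).
        self._check_shape(x=x, theta=theta)
        raise NotImplementedError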

@janfb (Contributor) commented Apr 25, 2024

Yes, this has a lot of conflicts with the changes to the shaping we made, and we want this PR to be merged ASAP.

@manuelgloeckler plans to continue this PR. I think he will either merge this branch into a new feature branch from main, or start from scratch and cherry-pick the changes from this PR.

I see the point about the _check_shape function, but I think we have that already, at least for the condition shape:

def _check_condition_shape(self, condition: Tensor):
    r"""This method checks whether the condition has the correct shape.

    Args:
        condition: Conditions of shape (*batch_shape, *condition_shape).

    Raises:
        ValueError: If the condition has a dimensionality that does not match
            the expected input dimensionality.
        ValueError: If the shape of the condition does not match the expected
            input dimensionality.
    """
    if len(condition.shape) < len(self.condition_shape):
        raise ValueError(
            "Dimensionality of condition is too small and does not match the "
            f"expected input dimensionality {len(self.condition_shape)}, as "
            "provided by condition_shape."
        )
    else:
        condition_shape = condition.shape[-len(self.condition_shape) :]
        if tuple(condition_shape) != tuple(self.condition_shape):
            raise ValueError(
                f"Shape of condition {tuple(condition_shape)} does not match the "
                f"expected input dimensionality {tuple(self.condition_shape)}, as "
                "provided by condition_shape. Please reshape it accordingly."
            )

And we also have other functions that are used by the SBI methods to check the correctness of the shape, e.g.,

def reshape_to_batch_event(theta_or_x: Tensor, event_shape: torch.Size) -> Tensor:
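
For orientation only, a rough stand-in for what a helper of this kind does: collapse all leading batch dimensions into a single batch dimension while keeping the trailing event_shape intact. This is an assumption-based sketch, not the actual sbi implementation:

# Illustrative stand-in -- not the sbi implementation.
import torch
from torch import Tensor


def reshape_to_batch_event(theta_or_x: Tensor, event_shape: torch.Size) -> Tensor:
    # Flatten all leading (batch) dimensions into one; keep the event dimensions.
    return theta_or_x.reshape(-1, *event_shape)


x = torch.randn(5, 2, 3)  # batch_shape = (5, 2), event_shape = (3,)
assert reshape_to_batch_event(x, torch.Size([3])).shape == torch.Size([10, 3])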
