Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

998 improving abc methods for trial based data using statistical distances #1104

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

theogruner
Copy link
Contributor

@theogruner theogruner commented Mar 22, 2024

What does this implement/fix? Explain your changes

Adding an approximation of the squared 2-Wasserstein distance based on Sinkhorn iterations as an additional statistical distance to the available metrics. Furthermore, extending MCABC and SMCABC to allow conditioning on multiple observations using statistical distances.

Does this close any currently open issues?

Fixes #998

Checklist

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, this is great!
Added a couple of comments.

sbi/inference/abc/abc_base.py Outdated Show resolved Hide resolved
@@ -98,7 +127,47 @@ def l2_distance(xo, x):
def l1_distance(xo, x):
return torch.mean(abs(xo - x), dim=-1)

distance_functions = {"mse": mse_distance, "l2": l2_distance, "l1": l1_distance}
def mmd_squared(xo, x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given we have so many different distance functions now, I think it is time to refactor this and move them out of this function to the top level or to a separate file. would you agree?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add types and docstrings as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. My suggestion would be to add a Distance or Metric class in the sbi/utils/metrics.py which builds one of the chosen distances or a custom one. We can further set the allow_iid flag within the new class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now moved the distance functions to a separate distance class within the abc folder. I did not want to put it into sbi/utils/metrics.py as its implementation is specific to ABC and should not be used outside of it.

sbi/inference/abc/abc_base.py Outdated Show resolved Hide resolved
if isinstance(distance_type, Callable):
return distance_type
if allow_iid is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens with allow_iid else? I think it would be good to not keep is None, but to set it to either True or False at the beginning of this functions. Otherwise, pyright will likely complain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the type already specified to be a bool or None? Therefore, if it not None upright will treat it as a bool.


return distance(observed_data, simulated_data)

return distance_fun
is_statistical_distance = distance_type in implemented_statistical_distances
if allow_iid is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow_iid should be True or False

sbi/utils/metrics.py Show resolved Hide resolved
tests/abc_test.py Outdated Show resolved Hide resolved
tests/abc_test.py Show resolved Hide resolved
tests/abc_test.py Show resolved Hide resolved
@@ -128,3 +130,35 @@ def test_c2st_scores(dist_sigma, c2st_lowerbound, c2st_upperbound):
assert obs2_c2st.mean() <= c2st_upperbound

assert np.allclose(obs2_c2st, obs_c2st, atol=0.05)


@pytest.mark.slow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add tests for the other distances as well? that'd be great!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'll add them :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests for the unbiased and biased MMD based on hypothesis tests.

@theogruner theogruner requested a review from janfb March 29, 2024 17:37
…hods-for-trial-based-data-using-statistical-distances
@theogruner theogruner marked this pull request as ready for review April 2, 2024 19:40
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks for the edits!

I added some last comments, once those addressed, the PR can be merged 🎉

@@ -54,7 +66,9 @@ def __init__(
self.x_shape = None

# Select distance function.
self.distance = self.get_distance_function(distance)
self.distance = Distance(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines +22 to +24
distance:
requires_iid_data:
distance_kwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add short desc.

except KeyError as exc:
raise KeyError(f"Distance {distance} not supported.") from exc

def __call__(self, xo, x) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __call__(self, xo, x) -> torch.Tensor:
def __call__(self, x_o, x) -> torch.Tensor:

"""Distance evaluation between the reference data and the simulated data.

Args:
xo: Reference data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
xo: Reference data
x_o: Observed data

"""
if self.requires_iid_data:
assert x.ndim >= 3, "simulated data needs batch dimension"
assert xo.ndim + 1 == x.ndim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert xo.ndim + 1 == x.ndim
assert x_o.ndim + 1 == x.ndim

else:
assert x.ndim >= 2, "simulated data needs batch dimension"
if self.batch_size == -1:
return self.distance_fn(xo, x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.distance_fn(xo, x)
return self.distance_fn(x_o, x)

if self.batch_size == -1:
return self.distance_fn(xo, x)
else:
return self._batched_distance(xo, x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self._batched_distance(xo, x)
return self._batched_distance(x_o, x)

Comment on lines +80 to +103
def _batched_distance(self, xo, x):
"""Evaluate the distance is mini-batches.
Especially for statistical distances, batching over two empirical
datasets can lead to memory overflow. Batching can help to resolve
the memory problems.

Args:
xo: Reference data
x: Simulated data
"""
num_batches = x.shape[0] // self.batch_size - 1
remaining = x.shape[0] % self.batch_size
if remaining == 0:
remaining = self.batch_size

distances = torch.empty(x.shape[0])
for i in tqdm(range(num_batches)):
distances[self.batch_size * i : (i + 1) * self.batch_size] = (
self.distance_fn(xo, x[self.batch_size * i : (i + 1) * self.batch_size])
)
if remaining > 0:
distances[-remaining:] = self.distance_fn(xo, x[-remaining:])

return distances
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename xo to x_o as above.

@janfb janfb self-assigned this Apr 4, 2024
@janfb
Copy link
Contributor

janfb commented Apr 12, 2024

Ping @theogruner 🙂 We are so close to merging here.

@janfb
Copy link
Contributor

janfb commented Apr 24, 2024

Ping @theogruner will you have time to finish this soon?
Please let me know if not, so that we can go ahead for the release - thanks 🙏

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improving ABC methods for trial-based data using statistical distances
2 participants