998 improving abc methods for trial based data using statistical distances #1104
base: main
… based on regularized optimal transport
…hods-for-trial-based-data-using-statistical-distances
…hods-for-trial-based-data-using-statistical-distances
Thanks a lot, this is great!
Added a couple of comments.
sbi/inference/abc/abc_base.py
Outdated
@@ -98,7 +127,47 @@ def l2_distance(xo, x):
def l1_distance(xo, x):
    return torch.mean(abs(xo - x), dim=-1)

distance_functions = {"mse": mse_distance, "l2": l2_distance, "l1": l1_distance}
def mmd_squared(xo, x):
Given that we now have so many different distance functions, I think it is time to refactor this and move them out of this function, either to the top level or to a separate file. Would you agree?
Please add types and docstrings as well.
I agree. My suggestion would be to add a Distance or Metric class in sbi/utils/metrics.py that builds one of the chosen distances or a custom one. We can further set the allow_iid flag within the new class.
I have now moved the distance functions to a separate distance class within the abc folder. I did not want to put it into sbi/utils/metrics.py, as its implementation is specific to ABC and should not be used outside of it.
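The design discussed in this thread, a callable class that builds either one of the named distances or a user-supplied one, might look roughly like the sketch below. The class name Distance and the error message appear in the PR; the registry `_DISTANCES`, the list-based signatures, and the two example distances are assumptions for illustration, and plain Python lists stand in for torch tensors so that the sketch has no dependencies.

```python
from typing import Callable, Dict, List, Union

Vector = List[float]
DistanceFn = Callable[[Vector, List[Vector]], List[float]]


def l1_distance(x_o: Vector, x: List[Vector]) -> List[float]:
    """Mean absolute difference between x_o and every row of x."""
    return [sum(abs(a - b) for a, b in zip(x_o, row)) / len(x_o) for row in x]


def mse_distance(x_o: Vector, x: List[Vector]) -> List[float]:
    """Mean squared difference between x_o and every row of x."""
    return [sum((a - b) ** 2 for a, b in zip(x_o, row)) / len(x_o) for row in x]


# Hypothetical registry of named distances (not the PR's actual table).
_DISTANCES: Dict[str, DistanceFn] = {"l1": l1_distance, "mse": mse_distance}


class Distance:
    """Callable wrapper around a named or user-supplied distance."""

    def __init__(self, distance: Union[str, DistanceFn] = "l1") -> None:
        if callable(distance):
            # A custom callable is used as-is.
            self.distance_fn = distance
        else:
            try:
                self.distance_fn = _DISTANCES[distance]
            except KeyError as exc:
                raise KeyError(f"Distance {distance} not supported.") from exc

    def __call__(self, x_o: Vector, x: List[Vector]) -> List[float]:
        return self.distance_fn(x_o, x)
```

Keeping the lookup inside the constructor means callers only ever hold one object, whether they passed a string or their own function.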
sbi/inference/abc/abc_base.py
Outdated
if isinstance(distance_type, Callable):
    return distance_type
if allow_iid is None:
What happens with allow_iid otherwise? I think it would be good not to keep the "is None" check, but to set allow_iid to either True or False at the beginning of this function. Otherwise, pyright will likely complain.
Isn't the type already specified to be a bool or None? Therefore, if it is not None, pyright will treat it as a bool.
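The point under discussion, resolving an optional flag to a definite bool before it is used, can be sketched as follows. The helper name `resolve_allow_iid`, the set `STATISTICAL_DISTANCES`, and the chosen default are assumptions for illustration and are not taken from the PR.

```python
from typing import Optional

# Hypothetical set of distances that compare whole datasets.
STATISTICAL_DISTANCES = {"mmd", "wasserstein"}


def resolve_allow_iid(distance_type: str, allow_iid: Optional[bool] = None) -> bool:
    """Turn the optional flag into a definite bool at the top of the function."""
    if allow_iid is None:
        # Assumed default: only statistical distances accept i.i.d. trials.
        allow_iid = distance_type in STATISTICAL_DISTANCES
    # From here on the value is always True or False, never None.
    return allow_iid
```

An explicit True or False passed by the caller is returned unchanged; only the None case is filled in.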
sbi/inference/abc/abc_base.py
Outdated
return distance(observed_data, simulated_data)

return distance_fun
is_statistical_distance = distance_type in implemented_statistical_distances
if allow_iid is not None:
allow_iid should be True or False.
@@ -128,3 +130,35 @@ def test_c2st_scores(dist_sigma, c2st_lowerbound, c2st_upperbound):
    assert obs2_c2st.mean() <= c2st_upperbound

    assert np.allclose(obs2_c2st, obs_c2st, atol=0.05)


@pytest.mark.slow
Could you add tests for the other distances as well? That'd be great!
Sure, I'll add them :)
Added tests for the unbiased and biased MMD based on hypothesis tests.
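For readers unfamiliar with the two estimators mentioned here, a minimal sketch of the biased and unbiased squared MMD is given below. The Gaussian kernel, the `scale` parameter, and all function names are assumptions for illustration; the PR's implementation works on torch tensors and may differ. The biased estimator averages over all pairs including i == j, while the unbiased one leaves the diagonal out of the within-sample terms and can therefore be negative.

```python
import math
from typing import Sequence

Sample = Sequence[Sequence[float]]


def gaussian_kernel(a: Sequence[float], b: Sequence[float], scale: float = 1.0) -> float:
    """Gaussian (RBF) kernel between two points."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * scale**2))


def _mean_kernel(xs: Sample, ys: Sample, scale: float, skip_diagonal: bool) -> float:
    """Average kernel value over all pairs, optionally excluding i == j."""
    total, count = 0.0, 0
    for i, a in enumerate(xs):
        for j, b in enumerate(ys):
            if skip_diagonal and i == j:
                continue
            total += gaussian_kernel(a, b, scale)
            count += 1
    return total / count


def biased_mmd_squared(xs: Sample, ys: Sample, scale: float = 1.0) -> float:
    """V-statistic estimate: within-sample terms include the diagonal."""
    return (
        _mean_kernel(xs, xs, scale, False)
        + _mean_kernel(ys, ys, scale, False)
        - 2.0 * _mean_kernel(xs, ys, scale, False)
    )


def unbiased_mmd_squared(xs: Sample, ys: Sample, scale: float = 1.0) -> float:
    """U-statistic estimate: within-sample terms skip i == j (needs >= 2 points)."""
    return (
        _mean_kernel(xs, xs, scale, True)
        + _mean_kernel(ys, ys, scale, True)
        - 2.0 * _mean_kernel(xs, ys, scale, False)
    )
```

A simple sanity check is that both estimators should be clearly larger for samples from well-separated distributions than for samples from nearly identical ones.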
…hods-for-trial-based-data-using-statistical-distances
Great, thanks for the edits!
I added some last comments; once those are addressed, the PR can be merged 🎉
@@ -54,7 +66,9 @@ def __init__(
        self.x_shape = None

        # Select distance function.
        self.distance = self.get_distance_function(distance)
        self.distance = Distance(
👍
distance:
requires_iid_data:
distance_kwargs:
Please add a short description for each argument.
except KeyError as exc:
    raise KeyError(f"Distance {distance} not supported.") from exc

def __call__(self, xo, x) -> torch.Tensor:
- def __call__(self, xo, x) -> torch.Tensor:
+ def __call__(self, x_o, x) -> torch.Tensor:

"""Distance evaluation between the reference data and the simulated data.

Args:
    xo: Reference data
- xo: Reference data
+ x_o: Observed data

"""
if self.requires_iid_data:
    assert x.ndim >= 3, "simulated data needs batch dimension"
    assert xo.ndim + 1 == x.ndim
- assert xo.ndim + 1 == x.ndim
+ assert x_o.ndim + 1 == x.ndim

else:
    assert x.ndim >= 2, "simulated data needs batch dimension"
if self.batch_size == -1:
    return self.distance_fn(xo, x)
- return self.distance_fn(xo, x)
+ return self.distance_fn(x_o, x)

if self.batch_size == -1:
    return self.distance_fn(xo, x)
else:
    return self._batched_distance(xo, x)
- return self._batched_distance(xo, x)
+ return self._batched_distance(x_o, x)

def _batched_distance(self, xo, x):
    """Evaluate the distance is mini-batches.
    Especially for statistical distances, batching over two empirical
    datasets can lead to memory overflow. Batching can help to resolve
    the memory problems.

    Args:
        xo: Reference data
        x: Simulated data
    """
    num_batches = x.shape[0] // self.batch_size - 1
    remaining = x.shape[0] % self.batch_size
    if remaining == 0:
        remaining = self.batch_size

    distances = torch.empty(x.shape[0])
    for i in tqdm(range(num_batches)):
        distances[self.batch_size * i : (i + 1) * self.batch_size] = (
            self.distance_fn(xo, x[self.batch_size * i : (i + 1) * self.batch_size])
        )
    if remaining > 0:
        distances[-remaining:] = self.distance_fn(xo, x[-remaining:])

    return distances
Please rename xo to x_o as above.
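The mini-batch evaluation in the diff can also be written with a stepped range, so that every row is visited exactly once whether or not the batch size divides the number of simulations. The sketch below is an illustrative alternative formulation, not the PR's code: `distance_fn`, `x_o`, `x`, and `batch_size` mirror the names in the diff, while the free function `batched_distance` and the use of plain lists instead of tensors are assumptions made to keep it dependency-free.

```python
from typing import Callable, List, Sequence

Row = Sequence[float]


def batched_distance(
    distance_fn: Callable[[Row, Sequence[Row]], List[float]],
    x_o: Row,
    x: Sequence[Row],
    batch_size: int,
) -> List[float]:
    """Evaluate distance_fn on consecutive chunks of x and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    distances: List[float] = []
    # Slicing past the end is safe, so the last chunk may simply be shorter.
    for start in range(0, len(x), batch_size):
        chunk = x[start : start + batch_size]
        distances.extend(distance_fn(x_o, chunk))
    return distances
```

Because the final slice is allowed to be short, no separate bookkeeping for the remaining rows is needed.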
Ping @theogruner 🙂 We are so close to merging here.
Ping @theogruner will you have time to finish this soon?
What does this implement/fix? Explain your changes
Adding an approximation of the squared 2-Wasserstein distance based on Sinkhorn iterations as an additional statistical distance to the available metrics. Furthermore, extending MCABC and SMCABC to allow conditioning on multiple observations using statistical distances.
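As a rough illustration of what this description refers to, the transport cost under entropy-regularized optimal transport between two equally weighted samples can be approximated with Sinkhorn iterations as in the sketch below. The function name, the default `epsilon` and `num_iters`, and the list-based inputs are assumptions made to keep the example dependency-free; the PR's implementation works on torch tensors and may differ in detail.

```python
import math
from typing import Sequence

Sample = Sequence[Sequence[float]]


def sinkhorn_w2_squared(
    xs: Sample, ys: Sample, epsilon: float = 1.0, num_iters: int = 200
) -> float:
    """Transport cost of the entropy-regularized plan with squared Euclidean cost."""
    n, m = len(xs), len(ys)
    # Pairwise squared Euclidean costs and the corresponding Gibbs kernel.
    cost = [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in ys] for x in xs]
    gibbs = [[math.exp(-c / epsilon) for c in row] for row in cost]
    # Uniform weights on both empirical samples.
    a, b = 1.0 / n, 1.0 / m
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(num_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = [a / sum(gibbs[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b / sum(gibbs[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Cost of the plan P_ij = u_i * K_ij * v_j.
    return sum(
        u[i] * gibbs[i][j] * v[j] * cost[i][j] for i in range(n) for j in range(m)
    )
```

A larger `epsilon` converges in fewer iterations but moves the result further above the unregularized distance. A very small `epsilon` relative to the costs needs many more iterations and can underflow in this plain (non log-domain) form.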
Does this close any currently open issues?
Fixes #998