add: prototype implementation of tarp in sbi #1106

Draft · psteinb wants to merge 11 commits into main
Conversation

@psteinb (Contributor) commented Mar 22, 2024

What does this implement/fix? Explain your changes

TARP is a diagnostics method that can help identify over-/underdispersion and bias in trained neural posteriors. The corresponding paper is here:
https://arxiv.org/abs/2302.03026

The reference NumPy implementation is here:
https://github.com/Ciela-Institute/tarp/
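For context, a minimal sketch of the core TARP computation as I read it from the paper; the function name, the shape conventions, and the uniform choice of reference points are illustrative, not this PR's API:

import torch
from torch import Tensor


def tarp_coverage_values(samples: Tensor, theta: Tensor) -> Tensor:
    """Fraction of posterior samples closer to a random reference than theta.

    samples: (num_samples, num_sims, dim) posterior draws per simulation.
    theta:   (num_sims, dim) ground-truth parameters.
    For a well-calibrated posterior, these fractions are uniform on [0, 1].
    """
    num_sims, dim = theta.shape
    # one uniform reference point per simulation, drawn within the range of theta
    lo, hi = theta.min(dim=0).values, theta.max(dim=0).values
    references = lo + (hi - lo) * torch.rand(num_sims, dim)

    dist_samples = torch.linalg.norm(samples - references, dim=-1)  # (num_samples, num_sims)
    dist_theta = torch.linalg.norm(theta - references, dim=-1)  # (num_sims,)
    return (dist_samples < dist_theta).float().mean(dim=0)  # (num_sims,)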

Does this close any currently open issues?

No; this was part of the March 2024 SBI hackathon in Tübingen.

Any relevant code examples, logs, error output, etc?

Not yet; I am trying to reproduce the examples given in the paper. At a later point in time, I'd like to bring the tests as well as a tutorial in line with what is available for SBC.

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)

@psteinb self-assigned this Mar 22, 2024

@psteinb (Contributor, Author) commented Mar 28, 2024

Note: I am stopping work on this PR for the time being. I ran into issues reproducing the TARP paper: Ciela-Institute/tarp#8
Once those are resolved, I'll continue working on this.

@psteinb closed this Mar 28, 2024
@psteinb reopened this Mar 28, 2024

@psteinb (Contributor, Author) commented Apr 12, 2024

Dear @janfb and @JuliaLinhart,
an alpha version of TARP (arXiv) as an SBI diagnostic is now ready from my point of view. I'd love for one of you to take a look. I added two files, sbi/diagnostics/tarp.py and tests/tarp_tests.py. The last unit test also documents how TARP would be used with SBI posterior predictions.

I have some questions, though:

  • At this point, the TARP coverage estimates are returned as raw numbers, i.e. I don't perform any hypothesis testing on them. Should I add one (e.g. a KS test; see the sketch after this list)?

  • The TARP diagnostic class currently implements a run and a check function (to be aligned with the SBC code). The run function practically doesn't do anything TARP-related but rather draws samples from the posterior; check actually performs TARP without any hypothesis test. I'm unclear whether run should instead compute the coverage stats and check perform the hypothesis test, for example. What do you think?

  • The TARP paper also offers a bootstrapped version of the diagnostic; would we want to have that in SBI too?

  • I think that if TARP is included in SBI, there should be a tutorial about it. I'd rather not make that part of this PR, though. Is that OK?
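Regarding the first question, a minimal sketch of what such a test could look like; coverage_values is a placeholder for the per-simulation TARP coverage fractions, while scipy.stats.kstest is the real SciPy API:

import torch
from scipy import stats

# Under a well-calibrated posterior, the TARP coverage fractions are
# uniform on [0, 1], so a one-sample KS test against the uniform CDF
# yields a p-value for miscalibration.
coverage_values = torch.rand(200)  # placeholder for real coverage values
ks_stat, p_value = stats.kstest(coverage_values.numpy(), "uniform")
print(f"KS statistic: {ks_stat:.3f}, p-value: {p_value:.3f}")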

@janfb (Contributor) left a comment

thanks a lot for contributing this @psteinb ! 🙏

I did a first pass (not the tests) and added a couple of comments.

Two high-level comments:

  1. Do we really need a TARP class? Couldn't we do it in a function that gets the samples, thetas, and a bunch of kwargs and basically does what check currently does? (See the sketch below.)
  2. It seems that TARP is very similar to SBC, except that it uses a different reduce_fn (l2 or l1) for the ranking and that it needs the references, no? Thus, maybe there is a way to incorporate it into the current implementation of sbc.py?
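To make point 1 concrete, a hypothetical sketch of such a function; the name check_tarp, the signature, and the shape conventions are made up for illustration and are not the code in this PR:

from typing import Optional, Tuple

import torch
from torch import Tensor


def check_tarp(
    samples: Tensor,  # (num_samples, num_sims, dim) posterior draws
    theta: Tensor,  # (num_sims, dim) ground-truth parameters
    references: Optional[Tensor] = None,
    num_bins: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """Return credibility levels alpha and the expected coverage at each level."""
    num_sims, dim = theta.shape
    num_bins = num_bins if num_bins is not None else num_sims // 10

    if references is None:
        # default: one uniform reference point per simulation
        lo, hi = theta.min(dim=0).values, theta.max(dim=0).values
        references = lo + (hi - lo) * torch.rand(num_sims, dim)

    # fraction of posterior samples closer to the reference than the true theta
    dist_samples = torch.linalg.norm(samples - references, dim=-1)
    dist_theta = torch.linalg.norm(theta - references, dim=-1)
    coverage = (dist_samples < dist_theta).float().mean(dim=0)

    # empirical CDF of the coverage fractions; a calibrated posterior
    # gives ecp close to alpha, i.e. the diagonal
    alpha = torch.linspace(0.0, 1.0, num_bins + 1)
    ecp = (coverage[None, :] <= alpha[:, None]).float().mean(dim=-1)
    return alpha, ecp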

Comment on lines +15 to +42
def l2(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """Calculates the L2 distance between two tensors.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor.
        axis (int, optional): The axis along which to calculate the L2 distance.
            Defaults to -1.

    Returns:
        Tensor: A tensor containing the L2 distance between x and y along the
            specified axis.
    """
    return torch.sqrt(torch.sum((x - y) ** 2, axis=axis))


def l1(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """Calculates the L1 distance between two tensors.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor.
        axis (int, optional): The axis along which to calculate the L1 distance.
            Defaults to -1.

    Returns:
        Tensor: A tensor containing the L1 distance between x and y along the
            specified axis.
    """
    return torch.sum(torch.abs(x - y), axis=axis)
I suggest moving these into sbi/utils/metrics.

Comment on lines +76 to +77
    posterior.set_default_x(xo)
    posterior.train()

Suggested change:
-    posterior.set_default_x(xo)
-    posterior.train()
+    posterior.train()


def __init__(
    self,
    references: Tensor = None,

Suggested change:
-    references: Tensor = None,
+    references: Optional[Tensor] = None,

    self,
    references: Tensor = None,
    metric: str = "euclidean",
    num_alpha_bins: Union[int, None] = None,

this will be set to n_bins in __init__, so I suggest just renaming it and staying consistent with num_

Suggested change:
-    num_alpha_bins: Union[int, None] = None,
+    num_bins: Optional[int] = None,

Comment on lines +120 to +121
    num_alpha_bins: number of bins to use for the credibility values.
        If ``None``, then ``n_sims // 10`` bins are used.

Suggested change:
-    num_alpha_bins: number of bins to use for the credibility values.
-        If ``None``, then ``n_sims // 10`` bins are used.
+    num_bins: number of bins to use for the credibility values.
+        If ``None``, then ``num_sims // 10`` bins are used.

Comment on lines +251 to +254
    if theta.shape[-2] != num_sims:
        raise ValueError("theta must have the same number of rows as samples")
    if theta.shape[-1] != num_dims:
        raise ValueError("theta must have the same number of columns as samples")

Suggested change:
-    if theta.shape[-2] != num_sims:
-        raise ValueError("theta must have the same number of rows as samples")
-    if theta.shape[-1] != num_dims:
-        raise ValueError("theta must have the same number of columns as samples")
+    assert theta.shape == samples.shape[1:], (
+        "number and dimensions of ground truth thetas must match the posterior samples."
+    )

"""
# TARP assumes that the predicted thetas are sampled from the "true"
# PDF num_samples times
theta = theta.detach() if len(theta.shape) != 2 else theta.detach().unsqueeze(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we assert already here that theta.shape == samples.shape[1:]? Why do we need the unsqueeze(0)?

    samples = (samples - lo) / (hi - lo + 1e-10)
    theta = (theta - lo) / (hi - lo + 1e-10)

    assert len(theta.shape) == len(samples.shape)

I am confused by this assert. I assumed that samples always has one dimension more than theta because it contains samples for many different x?

Comment on lines +264 to +297
    if not isinstance(self.references, Tensor):
        # obtain min/max per dimension of theta
        lo = (
            torch.min(theta, dim=-2).values.min(axis=0).values
        )  # should be 0 if normalized
        hi = (
            torch.max(theta, dim=-2).values.max(axis=0).values
        )  # should be 1 if normalized

        refpdf = torch.distributions.Uniform(low=lo, high=hi)
        self.references = refpdf.sample((1, num_sims))
    else:
        if len(self.references.shape) == 2:
            # add singleton dimension in front
            self.references = self.references.unsqueeze(0)

        if len(self.references.shape) == 3 and self.references.shape[0] != 1:
            raise ValueError(
                f"references must be a 2D array with a singular first "
                f"dimension, received {self.references.shape}"
            )

        if self.references.shape[-2] != num_sims:
            raise ValueError(
                f"references must have the same number of samples as samples, "
                f"received {self.references.shape[-2]} != {num_sims}"
            )

        if self.references.shape[-1] != num_dims:
            raise ValueError(
                "references must have the same number of dimensions as "
                f"samples or theta, received {self.references.shape[-1]} "
                f"!= {num_dims}"
            )
it seems that these lines act only on self.references. I suggest moving them to a separate function, e.g., def _check_references(...), and calling this function only once during init. Or am I missing something?
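A sketch of what such a helper could look like, mirroring the checks above (the name _check_references follows the suggestion; the exact signature is illustrative):

from torch import Tensor


def _check_references(references: Tensor, num_sims: int, num_dims: int) -> Tensor:
    """Validate user-provided references once, e.g. during __init__."""
    if len(references.shape) == 2:
        # add singleton dimension in front
        references = references.unsqueeze(0)
    if len(references.shape) != 3 or references.shape[0] != 1:
        raise ValueError(
            f"references must be a 2D array with a singular first "
            f"dimension, received {references.shape}"
        )
    if references.shape[-2] != num_sims:
        raise ValueError(
            f"references must have the same number of samples as samples, "
            f"received {references.shape[-2]} != {num_sims}"
        )
    if references.shape[-1] != num_dims:
        raise ValueError(
            f"references must have the same number of dimensions as samples "
            f"or theta, received {references.shape[-1]} != {num_dims}"
        )
    return references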

Comment on lines +303 to +311
    if self.metric_name.lower() in ["l2", "euclidean"]:
        distance = l2
    elif self.metric_name.lower() in ["l1", "manhattan"]:
        distance = l1
    else:
        raise ValueError(
            "metric must be either 'euclidean' or 'manhattan', "
            f"received {self.metric_name}"
        )
this could be done during __init__ as well, and then just set self.distance.
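For illustration, a minimal sketch of that suggestion, resolving the metric once at construction time; the dict and the class skeleton are hypothetical, and l2/l1 are the helpers from the diff above:

# l2 and l1 as defined in sbi/diagnostics/tarp.py (see the excerpt above)
_METRICS = {"l2": l2, "euclidean": l2, "l1": l1, "manhattan": l1}


class TARP:
    def __init__(self, metric: str = "euclidean") -> None:
        try:
            # resolve the distance function once instead of on every check() call
            self.distance = _METRICS[metric.lower()]
        except KeyError:
            raise ValueError(
                f"metric must be one of {sorted(_METRICS)}, received {metric!r}"
            ) from None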
