add: prototype implementation of tarp in sbi #1106
base: main
Conversation
Note: I am stopping work on this PR for the time being. I ran into issues reproducing the TARP paper: Ciela-Institute/tarp#8
Dear @janfb and @JuliaLinhart, I do have some questions, though:
thanks a lot for contributing this @psteinb ! 🙏
I did a first pass (not the tests) and added a couple of comments.
Two high-level comments:
- Do we really need a `TARP` class? Couldn't we do it in a function that gets the samples, thetas, and a bunch of `kwargs` and basically does what `check` currently is doing? (A rough sketch of such a function follows below.)
- It seems that `TARP` is very similar to `SBC`, except that it uses a different `reduce_fn` (`l2` or `l1`) for the ranking, and that it needs the `references`, no? Thus, maybe there is a way to incorporate it into the current implementation of `sbc.py`?
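For the first point, here is a rough, self-contained sketch of what such a function-based API could look like. The name `check_tarp`, the default uniform references, the binning rule, and the return values are assumptions pieced together from the diff below and the TARP paper, not the PR's actual code:

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


def check_tarp(
    samples: Tensor,  # (num_posterior_samples, num_sims, num_dims)
    theta: Tensor,  # (num_sims, num_dims)
    references: Optional[Tensor] = None,
    num_bins: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """Return (alpha, ecp): credibility levels and expected coverage probabilities."""
    num_posterior_samples, num_sims, num_dims = samples.shape
    if references is None:
        # one uniform reference point per simulation (assumes normalized data)
        references = torch.rand(num_sims, num_dims)
    num_bins = num_bins if num_bins is not None else num_sims // 10

    # L2 distance from each reference to the true theta and to every sample
    theta_dist = torch.linalg.norm(references - theta, dim=-1)
    sample_dist = torch.linalg.norm(references - samples, dim=-1)

    # coverage value per simulation: fraction of posterior samples that lie
    # closer to the reference point than the true theta does
    coverage = (sample_dist < theta_dist).float().mean(dim=0)

    # expected coverage = empirical CDF of the coverage values on an alpha grid
    alpha = torch.linspace(0.0, 1.0, num_bins + 1)
    ecp = (coverage[None, :] <= alpha[:, None]).float().mean(dim=-1)
    return alpha, ecp
```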
```python
def l2(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """
    Calculates the L2 distance between two tensors.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor.
        axis (int, optional): The axis along which to calculate the L2 distance.
            Defaults to -1.

    Returns:
        Tensor: A tensor containing the L2 distance between x and y along the
            specified axis.
    """
    return torch.sqrt(torch.sum((x - y) ** 2, axis=axis))


def l1(x: Tensor, y: Tensor, axis=-1) -> Tensor:
    """
    Calculates the L1 distance between two tensors.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor.
        axis (int, optional): The axis along which to calculate the L1 distance.
            Defaults to -1.

    Returns:
        Tensor: A tensor containing the L1 distance between x and y along the
            specified axis.
    """
    return torch.sum(torch.abs(x - y), axis=axis)
```
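For illustration, a quick usage example of these two helpers (shapes and values are made up):

```python
import torch

x = torch.zeros(4, 3)
y = torch.ones(4, 3)

print(l2(x, y))  # shape (4,), each entry sqrt(3) ~= 1.732
print(l1(x, y))  # shape (4,), each entry 3.0
```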
I suggest moving these into `sbi/utils/metrics`.
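If the helpers were moved as suggested, call sites would import them roughly like this (the module path is taken from the comment; the exact layout is an assumption):

```python
from sbi.utils.metrics import l1, l2
```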
```python
posterior.set_default_x(xo)
posterior.train()
```
Suggested change:

```diff
-posterior.set_default_x(xo)
 posterior.train()
```
```python
def __init__(
    self,
    references: Tensor = None,
```
Suggested change:

```diff
-    references: Tensor = None,
+    references: Optional[Tensor] = None,
```
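This assumes the corresponding import at the top of the module:

```python
from typing import Optional
```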
```python
    self,
    references: Tensor = None,
    metric: str = "euclidean",
    num_alpha_bins: Union[int, None] = None,
```
This will be set to `n_bins` in `__init__`, so I suggest just renaming it and staying consistent with the `num_` prefix:
Suggested change:

```diff
-    num_alpha_bins: Union[int, None] = None,
+    num_bins: Optional[int] = None,
```
```python
        num_alpha_bins: number of bins to use for the credibility values.
            If ``None``, then ``n_sims // 10`` bins are used.
```
Suggested change:

```diff
-        num_alpha_bins: number of bins to use for the credibility values.
-            If ``None``, then ``n_sims // 10`` bins are used.
+        num_bins: number of bins to use for the credibility values.
+            If ``None``, then ``num_sims // 10`` bins are used.
```
```python
if theta.shape[-2] != num_sims:
    raise ValueError("theta must have the same number of rows as samples")
if theta.shape[-1] != num_dims:
    raise ValueError("theta must have the same number of columns as samples")
```
Suggested change:

```diff
-if theta.shape[-2] != num_sims:
-    raise ValueError("theta must have the same number of rows as samples")
-if theta.shape[-1] != num_dims:
-    raise ValueError("theta must have the same number of columns as samples")
+assert theta.shape == samples.shape[1:], (
+    "number and dimensions of ground truth thetas must match the posterior samples."
+)
```
""" | ||
# TARP assumes that the predicted thetas are sampled from the "true" | ||
# PDF num_samples times | ||
theta = theta.detach() if len(theta.shape) != 2 else theta.detach().unsqueeze(0) |
Can't we assert already here that `theta.shape == samples.shape[1:]`? Why do we need the `unsqueeze(0)`?
```python
samples = (samples - lo) / (hi - lo + 1e-10)
theta = (theta - lo) / (hi - lo + 1e-10)

assert len(theta.shape) == len(samples.shape)
```
I am confused by this assert. I assumed that `samples` always has one dimension more than `theta`, because it contains samples for many different `x`?
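For reference, a minimal sketch of the shape convention described here (all names and sizes are illustrative, not taken from the PR):

```python
import torch

num_posterior_samples, num_sims, num_dims = 100, 50, 3

# posterior samples: a batch of draws for each observation x
samples = torch.randn(num_posterior_samples, num_sims, num_dims)
# ground-truth parameters: one theta per observation x
theta = torch.randn(num_sims, num_dims)

# under this convention, samples has exactly one leading dimension more,
# and the check from the comment above holds without any unsqueeze(0)
assert theta.shape == samples.shape[1:]
```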
```python
if not isinstance(self.references, Tensor):
    # obtain min/max per dimension of theta
    lo = (
        torch.min(theta, dim=-2).values.min(axis=0).values
    )  # should be 0 if normalized
    hi = (
        torch.max(theta, dim=-2).values.max(axis=0).values
    )  # should be 1 if normalized

    refpdf = torch.distributions.Uniform(low=lo, high=hi)
    self.references = refpdf.sample((1, num_sims))
else:
    if len(self.references.shape) == 2:
        # add singleton dimension in front
        self.references = self.references.unsqueeze(0)

if len(self.references.shape) == 3 and self.references.shape[0] != 1:
    raise ValueError(
        f"""references must be a 2D array with a singular first
        dimension, received {self.references.shape}"""
    )

if self.references.shape[-2] != num_sims:
    raise ValueError(
        f"references must have the same number of samples as samples, "
        f"received {self.references.shape[-2]} != {num_sims}"
    )

if self.references.shape[-1] != num_dims:
    raise ValueError(
        "references must have the same number of dimensions as "
        f"samples or theta, received {self.references.shape[-1]} "
        f"!= {num_dims}"
    )
```
It seems that these lines act only on `self.references`. I suggest moving them to a separate function, e.g., `def _check_references(...)`, and calling this function only once during init. Or am I missing something?
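A rough sketch of what such a helper could look like, based on the checks quoted above (the function name comes from the comment; the signature and the return-the-tensor pattern are assumptions):

```python
from torch import Tensor


def _check_references(references: Tensor, num_sims: int, num_dims: int) -> Tensor:
    """Validate user-provided reference points and normalize their shape."""
    if len(references.shape) == 2:
        # add singleton dimension in front
        references = references.unsqueeze(0)
    if len(references.shape) == 3 and references.shape[0] != 1:
        raise ValueError(
            "references must be a 2D array with a singular first "
            f"dimension, received {references.shape}"
        )
    if references.shape[-2] != num_sims:
        raise ValueError(
            "references must have the same number of samples as samples, "
            f"received {references.shape[-2]} != {num_sims}"
        )
    if references.shape[-1] != num_dims:
        raise ValueError(
            "references must have the same number of dimensions as "
            f"samples or theta, received {references.shape[-1]} != {num_dims}"
        )
    return references
```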
```python
if self.metric_name.lower() in ["l2", "euclidean"]:
    distance = l2
elif self.metric_name.lower() in ["l1", "manhattan"]:
    distance = l1
else:
    raise ValueError(
        "metric must be either 'euclidean' or 'manhattan', "
        f"received {self.metric_name}"
    )
```
This could be done during `__init__` as well, and then just set `self.distance`.
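A minimal sketch of that idea, assuming the `l1`/`l2` helpers from above (the dict-based lookup and the simplified `__init__` signature are hypothetical):

```python
from typing import Callable, Dict

# map accepted metric names to the distance helpers defined earlier
_DISTANCES: Dict[str, Callable] = {
    "l2": l2,
    "euclidean": l2,
    "l1": l1,
    "manhattan": l1,
}


class TARP:
    def __init__(self, metric: str = "euclidean") -> None:
        # resolve the distance function once, instead of on every check() call
        try:
            self.distance = _DISTANCES[metric.lower()]
        except KeyError:
            raise ValueError(
                f"metric must be one of {sorted(_DISTANCES)}, received '{metric}'"
            ) from None
```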
What does this implement/fix? Explain your changes
TARP is a diagnostics method that can help identify over-/underdispersion and bias in trained neural posteriors. The corresponding paper is here:
https://arxiv.org/abs/2302.03026
The reference NumPy implementation is here:
https://github.com/Ciela-Institute/tarp/
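For orientation, a hedged sketch of how the prototype might be invoked (the `TARP` class name and the `metric` kwarg appear in the diff above; the `check` method name, argument order, and return values are assumptions):

```python
import torch

# illustrative shapes: 100 posterior draws for each of 50 observations, 3 parameters
samples = torch.randn(100, 50, 3)  # posterior samples per observation
theta = torch.randn(50, 3)         # ground-truth parameters per observation

tarp = TARP(metric="euclidean")
ecp, alpha = tarp.check(samples, theta)  # expected coverage vs. credibility levels
```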
Does this close any currently open issues?
No, this was part of the Mar 2024 SBI hackathon in Tübingen
Any relevant code examples, logs, error output, etc?
Not yet; I am trying to reproduce the examples given in the paper. At a later point in time, I'd like to bring the tests as well as a tutorial in line with what is available for SBC.
Checklist

Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

- [ ] Tests follow the guidelines, with slow tests marked with `pytest.mark.slow`.
- [ ] Docstrings follow the guidelines.
- [ ] The branch is up to date with `main` (or there are no conflicts with `main`).