Fixes error in get losses functions #2362
base: main
Changes from 21 commits
@@ -473,7 +473,7 @@ def loss(
                 generative_outputs["pl"],
             ).sum(dim=1)
         else:
-            kl_divergence_l = torch.tensor(0.0, device=x.device)
+            kl_divergence_l = torch.zeros_like(kl_divergence_z)

        reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
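The change above replaces a 0-dim scalar placeholder with a zero tensor shaped like the per-cell KL term. A minimal sketch (not scvi-tools code; variable shapes are assumed) of why the shape-matched placeholder matters when loss components are later combined per cell:

```python
import torch

# Assumed per-cell KL term: one value per cell in the minibatch.
kl_divergence_z = torch.rand(8)

# Old behavior: a 0-dim scalar placeholder.
kl_l_scalar = torch.tensor(0.0)
# New behavior: a zero vector matching shape, dtype, and device of kl_z.
kl_l_vector = torch.zeros_like(kl_divergence_z)

# Stacking per-cell loss components fails for the 0-dim scalar ...
try:
    torch.stack([kl_divergence_z, kl_l_scalar])
except RuntimeError as e:
    print("scalar placeholder fails:", type(e).__name__)

# ... but works for the shape-matched vector.
stacked = torch.stack([kl_divergence_z, kl_l_vector])
print(stacked.shape)  # torch.Size([2, 8])
```

`zeros_like` also inherits device and dtype automatically, so no explicit `device=x.device` is needed.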
@@ -609,6 +609,8 @@ def marginal_ll(
            q_l_x = ql.log_prob(library).sum(dim=-1)

            log_prob_sum += p_l - q_l_x
+           if n_mc_samples_per_pass == 1:
+               log_prob_sum = log_prob_sum.unsqueeze(0)
Comment on lines +612 to +613:

Review comment: Wait, so does this mean that this method was not working properly before? Since the default is …

Reply: Yes, but we only use it for DE genes, where it's called with n_samples_per_mc of 100 (?). scANVI wasn't supporting importance weighting in DE beforehand.
            to_sum.append(log_prob_sum)
        to_sum = torch.cat(to_sum, dim=0)
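The `unsqueeze(0)` guard matters because `torch.cat(..., dim=0)` rejects 0-dim tensors. An illustrative sketch (values and names are hypothetical, not the model's actual outputs) of the failure mode and the fix mirrored from the diff:

```python
import torch

# With a single MC sample per pass, summing over the sample dimension can
# collapse each per-pass result to a 0-dim tensor.
per_pass = [torch.tensor(1.5), torch.tensor(2.5)]

# torch.cat cannot concatenate zero-dimensional tensors along dim 0.
try:
    torch.cat(per_pass, dim=0)
except RuntimeError as e:
    print("0-dim tensors cannot be concatenated:", type(e).__name__)

# The fix from the diff: unsqueeze each result to shape (1,) before collecting.
fixed = [t.unsqueeze(0) for t in per_pass]
out = torch.cat(fixed, dim=0)
print(out)  # tensor([1.5000, 2.5000])
```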
Review comment: Probably best to also use torch operations here, since `compute_reconstruction_error` doesn't use numpy. Let me know what you think.

Reply: Fine with me. Can you push those changes?
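The suggestion above is to keep the marginal-likelihood aggregation in torch rather than converting to numpy. A hedged sketch of what that aggregation could look like, assuming `to_sum` holds per-pass log-probability sums with shape `(n_passes, n_cells)` (the layout is an assumption here):

```python
import math

import torch

# Hypothetical per-pass log-probability sums: (n_passes, n_cells).
to_sum = torch.randn(5, 100)

n_total = to_sum.shape[0]
# log-mean-exp over MC passes, entirely in torch: a numerically stable
# estimate of log p(x) from importance samples.
batch_log_lkl = torch.logsumexp(to_sum, dim=0) - math.log(n_total)
print(batch_log_lkl.shape)  # torch.Size([100])
```

Staying in torch avoids a device-to-host copy and keeps the function consistent with `compute_reconstruction_error`.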