Skip to content

Commit

Permalink
Fixed ct_specific_expression.
Browse files Browse the repository at this point in the history
  • Loading branch information
canergen committed Jul 29, 2022
1 parent cbfa977 commit 365a4a0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions scvi/module/_mrdeconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def get_ct_specific_expression(
# cell-type specific gene expression, shape (minibatch, celltype, gene).
eps = torch.nn.functional.softplus(self.eta) # n_genes
eps = eps.repeat((x.shape[0], 1)).view(
x.shape[0], 1, -1
x.shape[0], 1, x.shape[1]
) # (M, 1, n_genes) <- this is the dummy cell type
beta = torch.exp(self.beta) # n_genes

Expand All @@ -447,7 +447,7 @@ def get_ct_specific_expression(
else:
v_ind = torch.nn.functional.softplus(
self.V[:, ind_x]
) # n_spots, n_labels + 1
).T # n_spots, n_labels + 1
# remove dummy cell type proportion values

if self.amortization in ["both", "latent"]:
Expand All @@ -473,13 +473,13 @@ def get_ct_specific_expression(
px_ct = torch.cat(
[
beta.unsqueeze(0).unsqueeze(1)
* px_scale.reshape(-1, self.n_labels, self.n_latent),
* px_scale.reshape(-1, self.n_labels, self.n_genes),
eps,
],
dim=1,
)
expression = torch.expm1(x) * (
(v_ind[:, y] * px_ct[:, y, :])
(v_ind[:, y].unsqueeze(1) * px_ct[:, y, :])
/ torch.sum(v_ind.unsqueeze(2) * px_ct, dim=1)
)

Expand Down

0 comments on commit 365a4a0

Please sign in to comment.