In the ANP implementation, during training mode, the full set (i.e. context + target) is passed to `predict`, meaning that the log-likelihood is maximized for both target and context labels:
```python
if self.training:
    pz = self.lenc(batch.xc, batch.yc)
    qz = self.lenc(batch.x, batch.y)
    z = qz.rsample() if num_samples is None else \
            qz.rsample([num_samples])
    py = self.predict(batch.xc, batch.yc, batch.x,
            z=z, num_samples=num_samples)
    if num_samples > 1:
        # K * B * N
        recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
        # K * B
        log_qz = qz.log_prob(z).sum(-1)
        log_pz = pz.log_prob(z).sum(-1)
        # K * B
        log_w = recon.sum(-1) + log_pz - log_qz
        outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
```
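For context, the loss line above reduces the `K` importance weights per batch element with `logmeanexp`. The repository's own definition isn't shown here, so the following is only a sketch of the standard numerically stable log-mean-exp, written in NumPy for self-containment:

```python
import numpy as np

def logmeanexp(log_w, axis=0):
    """log(mean(exp(log_w))) along `axis`, computed stably by
    subtracting the per-slice maximum before exponentiating."""
    m = np.max(log_w, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.mean(np.exp(log_w - m), axis=axis))

# K = 3 samples, B = 2 batch elements
log_w = np.log(np.array([[1.0, 2.0],
                         [2.0, 4.0],
                         [3.0, 6.0]]))
# mean over K of exp(log_w) is [2.0, 4.0], so the result is [log 2, log 4]
print(logmeanexp(log_w))
```

Averaging in probability space (rather than of the log-weights) is what turns the `K` samples into the importance-weighted bound before the final `.mean()` over the batch.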
Shouldn't this line:

`py = self.predict(batch.xc, batch.yc, batch.x, z=z, num_samples=num_samples)`

be changed to this:

`py = self.predict(batch.xc, batch.yc, batch.xt, z=z, num_samples=num_samples)`

i.e. `batch.x` --> `batch.xt` in the 3rd argument of `predict`?