How to determine cond_fn in condition_mean #69

Open
cyuting940612 opened this issue Nov 22, 2023 · 4 comments

@cyuting940612

Hello everyone,

I'm trying to implement the manipulation method, and I found that the key to manipulation is modifying cond in the render_condition() function. The condition_mean() function computes a new mean for the diffusion sampling step. However, I cannot see where cond_fn comes from or how it is determined. I'd much appreciate any clue!

```python
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
    """
    if model_kwargs is None:
        model_kwargs = {}  # avoid a TypeError when unpacking None below
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    # Shift the predicted mean by variance * grad log p(y|x)
    new_mean = (p_mean_var["mean"].float() +
                p_mean_var["variance"] * gradient.float())
    return new_mean
```

```python
pred_img = render_condition(self.conf,
                            self.ema_model,
                            noise,
                            sampler=sampler,
                            cond=cond)
```
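For reference, the docstring above only requires cond_fn(x, t, **model_kwargs) to return grad_x log p(y|x) with the same shape as x. The minimal sketch below assumes a hypothetical toy likelihood p(y|x) = N(y; x, sigma^2 I), whose gradient is analytic, (y - x) / sigma^2. It stands in for the autograd-computed classifier gradient that classifier guidance typically uses. make_gaussian_cond_fn and the standalone condition_mean are illustrative names, not repo code, and the timestep scaling (self._scale_timesteps) is omitted.

```python
from typing import Callable, Dict, List

Vector = List[float]


def make_gaussian_cond_fn(y: Vector, sigma: float) -> Callable[..., Vector]:
    """Hypothetical cond_fn: grad_x log N(y; x, sigma^2 I) = (y - x) / sigma^2."""
    def cond_fn(x: Vector, t: int, **model_kwargs) -> Vector:
        # t and model_kwargs are accepted to match the call in condition_mean,
        # but this toy likelihood does not depend on them.
        return [(yi - xi) / sigma ** 2 for xi, yi in zip(x, y)]
    return cond_fn


def condition_mean(cond_fn, p_mean_var: Dict[str, Vector], x: Vector, t: int,
                   model_kwargs=None) -> Vector:
    """Standalone version of the quoted method: mean + variance * gradient."""
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    return [m + v * g for m, v, g in
            zip(p_mean_var["mean"], p_mean_var["variance"], gradient)]


x_t = [0.0, 0.0]
p_mean_var = {"mean": [0.0, 0.0], "variance": [0.5, 0.5]}
cond_fn = make_gaussian_cond_fn(y=[1.0, -1.0], sigma=1.0)
new_mean = condition_mean(cond_fn, p_mean_var, x_t, t=10)
print(new_mean)  # [0.5, -0.5]
```

The guided mean moves toward y by an amount proportional to the step's variance, which is the Sohl-Dickstein et al. (2015) update written in the docstring.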
@phizaz
Owner

phizaz commented Nov 22, 2023

Can you please point to the lines of code in question?

@cyuting940612
Author

Thanks for your reply.

Yes, of course. The condition_mean method is on lines 403-416 of base.py.

Also, in the manipulation example:

```python
cond2 = cond2 + 0.3 * math.sqrt(512) * F.normalize(cls_model.classifier.weight[cls_id][None, :], dim=1)
```

Could you please share how you determined the coefficient 0.3 * sqrt(512)?

Thanks for your help!

@phizaz
Owner

phizaz commented Nov 23, 2023

I believe lines 403-416 are not used in DiffAE at all, and hence are not related to DiffAE. They are legacy code from the base repo that we built upon.

Regarding your second question, 0.3 * sqrt(512): 0.3 was definitely tuned by hand based on qualitative results. The sqrt(512) factor is there because F.normalize(cls_model.classifier.weight[cls_id][None, :], dim=1) is a unit vector, which is rather small relative to a 512-dimensional latent, and how small depends on the number of dimensions. It's nice to have a coefficient that scales with the dimensionality, since that makes it more robust. A standard Gaussian random vector with 512 dimensions has a norm of roughly sqrt(512) (its expected squared norm is exactly 512), so multiplying a unit vector by sqrt(512) scales it as though it had the same size as a random Gaussian vector.
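A quick numerical check of this reasoning, in plain Python with no DiffAE code: it estimates the average norm of standard Gaussian vectors in 512 dimensions and compares it with sqrt(512), then confirms that 0.3 * sqrt(512) times a unit vector has norm 0.3 * sqrt(512), i.e. about 30% of a typical Gaussian vector's norm. The direction w is a hypothetical random stand-in for cls_model.classifier.weight[cls_id].

```python
import math
import random


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def normalize(v):
    # Plain-Python analogue of F.normalize(..., dim=1) for a single vector
    n = l2_norm(v)
    return [x / n for x in v]


DIM = 512
random.seed(0)

# Average norm of standard Gaussian vectors in DIM dimensions
norms = [l2_norm([random.gauss(0.0, 1.0) for _ in range(DIM)]) for _ in range(200)]
mean_norm = sum(norms) / len(norms)
print(f"mean Gaussian norm: {mean_norm:.2f}, sqrt(DIM): {math.sqrt(DIM):.2f}")

# Hypothetical stand-in for the classifier weight row; its original scale
# does not matter because it is normalized to unit length.
w = [random.uniform(-0.01, 0.01) for _ in range(DIM)]
step = [0.3 * math.sqrt(DIM) * x for x in normalize(w)]
ratio = l2_norm(step) / mean_norm
print(f"step norm: {l2_norm(step):.2f}, ratio to Gaussian norm: {ratio:.3f}")
```

Because the step size is tied to sqrt(DIM), the hand-tuned 0.3 keeps the meaning "a fraction of a typical latent's norm" even if the latent dimensionality changes.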

@cyuting940612
Author

I really appreciate your reply.

But I think lines 403-416 are actually used by img = model.render(xT, cond2, T=100) in cell [33] of manipulate.ipynb.

There, the render method can be conditioned via cond_fn, but I can't find where cond_fn is defined.

Could you please give me some clues on how to implement this conditional rendering?

Thank you!
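To make the distinction in this thread concrete: the semantic latent cond is handed to render_condition (and, presumably, on to the network as a model input), whereas cond_fn is an optional gradient-guidance hook. The toy sketch below assumes a sampling step structured like the common guided-diffusion pattern, in which condition_mean is applied only when a cond_fn is explicitly supplied; whether the sampling loop in base.py does exactly this is an assumption to check against the code. toy_model_mean and p_sample_mean are hypothetical scalar stand-ins, not DiffAE functions.

```python
from typing import Callable, Optional


def toy_model_mean(x: float, t: int, cond: float) -> float:
    # Hypothetical denoiser: the latent `cond` directly shifts the predicted mean.
    return 0.9 * x + 0.1 * cond


def p_sample_mean(x: float, t: int, variance: float, model_kwargs: dict,
                  cond_fn: Optional[Callable[..., float]] = None) -> float:
    mean = toy_model_mean(x, t, **model_kwargs)
    if cond_fn is not None:
        # Gradient guidance (what condition_mean does): only reached
        # when a cond_fn is passed in.
        mean = mean + variance * cond_fn(x, t, **model_kwargs)
    return mean


# Changing `cond` alone changes the sample mean, with no cond_fn involved.
m1 = p_sample_mean(x=1.0, t=10, variance=0.5, model_kwargs={"cond": 0.0})
m2 = p_sample_mean(x=1.0, t=10, variance=0.5, model_kwargs={"cond": 2.0})
# Supplying a cond_fn adds variance * gradient on top of the model's mean.
m3 = p_sample_mean(x=1.0, t=10, variance=0.5, model_kwargs={"cond": 0.0},
                   cond_fn=lambda x, t, cond: 1.0)
print(m1, m2, m3)
```

Under this assumption, editing cond (as in the cond2 manipulation) is enough to steer the output, and condition_mean only matters if a caller deliberately supplies a cond_fn.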
