AD Cleanup: Avoid dealing with invalid IR by representing out-of-scope accesses through Push and Pop instructions #4182

Open
saipraveenb25 opened this issue May 16, 2024 · 0 comments
Labels: goal:forward looking, kind:cleanup (tech debt and rough edges), priority:medium (nice to have in next milestone)

@saipraveenb25 (Collaborator)

The unzipping step creates two copies of the control-flow graph: one holding the primal instructions and the other holding the differential instructions. All differential blocks run after all primal blocks, so whenever the function has non-trivial branching, differential instructions end up referencing primal instructions that were defined inside conditional regions and are no longer in scope.

Example:

Consider a function `f` with branching control flow:

```
float f(float x)
{
    float val = 0.f;
    if (x < 0.5)
    {
        float k = 2 * x;
        val = k * x;
    }
    else
    {
        float k = 3 * x;
        val = k * x;
    }

    return val;
}
```

This will result in the following code after the unzipping step:

```
DifferentialPair<float> unzipped_f(DifferentialPair<float> dpx)
{
    float val = 0.f;
    float val_d = 0.f;

    if (dpx.p < 0.5)
    {
        float k = 2 * dpx.p;
        val = k * dpx.p;
    }
    else
    {
        float k = 3 * dpx.p;
        val = k * dpx.p;
    }

    if (dpx.p < 0.5)
    {
        float d_k = 2 * dpx.d;

        // Invalid use of 'k' which is defined in an inaccessible scope.
        val_d = d_k * dpx.p + k * dpx.d;
    }
    else
    {
        float d_k = 3 * dpx.d;

        // Invalid use of 'k' which is defined in an inaccessible scope.
        val_d = d_k * dpx.p + k * dpx.d;
    }

    return DifferentialPair<float>(val, val_d);
}
```

The result contains invalid out-of-scope accesses.
For now, we tolerate this invalid IR until the last step, when the checkpointing pass legalizes these accesses by inserting loads and stores as necessary.
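As a rough illustration of what legalization has to achieve, here is a C++ sketch of `unzipped_f` in which `k` is stored to a function-scope variable in the primal blocks and loaded in the differential blocks. This assumes the simplest possible store/load placement; the actual checkpointing pass may place storage differently. `DiffPairFloat`, `legalized_f`, and `k_saved` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical stand-in for DifferentialPair<float>.
struct DiffPairFloat {
    float p; // primal value
    float d; // derivative (tangent)
};

// Sketch of unzipped_f after legalization (hypothetical): 'k' is stored to a
// variable whose scope covers both the primal and the differential blocks.
DiffPairFloat legalized_f(DiffPairFloat dpx)
{
    float val = 0.f;
    float val_d = 0.f;
    float k_saved = 0.f; // storage introduced for the out-of-scope value

    // Primal blocks.
    if (dpx.p < 0.5f) {
        float k = 2 * dpx.p;
        k_saved = k; // store
        val = k * dpx.p;
    } else {
        float k = 3 * dpx.p;
        k_saved = k; // store
        val = k * dpx.p;
    }

    // Differential blocks: load k_saved instead of using the out-of-scope 'k'.
    if (dpx.p < 0.5f) {
        float d_k = 2 * dpx.d;
        val_d = d_k * dpx.p + k_saved * dpx.d;
    } else {
        float d_k = 3 * dpx.d;
        val_d = d_k * dpx.p + k_saved * dpx.d;
    }
    return DiffPairFloat{val, val_d};
}
```

For `x = 0.25` with tangent `1`, this yields `f = 2x² = 0.125` and `f' = 4x = 1`. For `x = 0.75` it yields `f = 3x² = 1.6875` and `f' = 6x = 4.5`.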

The out-of-scope accesses can be cleaned up in a more principled way by borrowing ideas from tape-based AD. That approach assumes an unbounded, dynamic stack (the tape) onto which the primal code pushes intermediate values; the differential blocks then pop those values off and use them.
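To make the tape idea concrete outside the compiler, here is a minimal C++ sketch of tape-based reverse mode for a loop that computes `x^n`. It assumes a `std::vector` as the unbounded stack. `Tape`, `ValueAndGrad`, and `pow_with_grad` are hypothetical names. Because the reverse pass visits iterations in the opposite order, a LIFO stack returns exactly the intermediate each step needs, and the number of pushes can depend on a runtime trip count.

```cpp
#include <cassert>
#include <vector>

// Minimal dynamic tape: the primal pass pushes, the reverse pass pops (LIFO).
struct Tape {
    std::vector<float> stack;
    void push(float v) { stack.push_back(v); }
    float pop() {
        assert(!stack.empty() && "pop without matching push");
        float v = stack.back();
        stack.pop_back();
        return v;
    }
};

struct ValueAndGrad {
    float value; // y = x^n
    float grad;  // dy/dx
};

// Computes y = x^n by repeated multiplication, then dy/dx in reverse mode.
ValueAndGrad pow_with_grad(float x, int n)
{
    Tape tape;

    // Primal pass: save the operand each multiply's derivative needs.
    float y = 1.f;
    for (int i = 0; i < n; ++i) {
        tape.push(y); // d(y * x)/dx = y, so record y before the update
        y = y * x;
    }

    // Reverse pass: replay iterations backwards, popping intermediates.
    float dy = 1.f; // adjoint of the current y (seeded with dy/dy = 1)
    float dx = 0.f; // accumulated adjoint of x
    for (int i = n - 1; i >= 0; --i) {
        float y_prev = tape.pop();
        dx += dy * y_prev; // contribution of y_prev * x with respect to x
        dy = dy * x;       // propagate the adjoint back to y_prev
    }
    return ValueAndGrad{y, dx};
}
```

For example, `pow_with_grad(2.f, 3)` gives a value of `8` and a gradient of `3 · 2² = 12`.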

We can implement this approach by introducing IR instructions such as IRTapePush and IRTapePop.
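The following C++ sketch shows what `unzipped_f` from the example could look like once each out-of-scope use of `k` goes through the tape. `tapePush` and `tapePop` are hypothetical runtime stand-ins for the proposed IRTapePush and IRTapePop instructions. The sketch also assumes a global `std::vector` as the tape; in the compiler these would remain IR instructions until lowered. The two pushes sit in mutually exclusive branches, so exactly one value is on the tape when the differential blocks run.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for DifferentialPair<float>.
struct DiffPairFloat {
    float p; // primal value
    float d; // derivative (tangent)
};

// Hypothetical runtime stand-ins for IRTapePush / IRTapePop.
static std::vector<float> g_tape;
static void tapePush(float v) { g_tape.push_back(v); }
static float tapePop()
{
    assert(!g_tape.empty() && "pop without matching push");
    float v = g_tape.back();
    g_tape.pop_back();
    return v;
}

DiffPairFloat unzipped_f_with_tape(DiffPairFloat dpx)
{
    float val = 0.f;
    float val_d = 0.f;

    // Primal blocks: each branch pushes the value its differential needs.
    if (dpx.p < 0.5f) {
        float k = 2 * dpx.p;
        tapePush(k);
        val = k * dpx.p;
    } else {
        float k = 3 * dpx.p;
        tapePush(k);
        val = k * dpx.p;
    }

    // Differential blocks: pop 'k' instead of referencing it out of scope.
    if (dpx.p < 0.5f) {
        float d_k = 2 * dpx.d;
        float k = tapePop();
        val_d = d_k * dpx.p + k * dpx.d;
    } else {
        float d_k = 3 * dpx.d;
        float k = tapePop();
        val_d = d_k * dpx.p + k * dpx.d;
    }
    return DiffPairFloat{val, val_d};
}
```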

Representing tape accesses explicitly lets us choose among different implementations depending on how these instructions are lowered. The current approach would then simply be the static lowering: we introduce a struct with a field for each unique IRTapePush instruction (and an array wherever an IRTapePush occurs inside a loop region).
This could also let user-written derivative code access the intermediate context. Currently, user-written reverse-mode derivatives need workarounds such as non-differentiable auxiliary parameters, or must be written as fully self-contained derivative methods.
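A C++ sketch of the static lowering described here: one struct field per push site, with an array for the site inside a loop. It assumes the loop's trip count has a known upper bound (`kMaxIters`, hypothetical) so the array can have a fixed size. The struct, its field names, and the split into `primal`/`backward` are illustrative and do not reflect the compiler's actual generated code. Because `backward` reads only `x` and the struct, the struct plays the role of the intermediate context mentioned below.

```cpp
#include <cassert>

// Hypothetical upper bound on the loop trip count, needed to size the array.
constexpr int kMaxIters = 16;

// One field per unique push site (hypothetical names).
struct StaticTape {
    float k;                 // site 0: pushed once, outside any loop
    float y_final;           // site 1: pushed once, after the loop
    float y_prev[kMaxIters]; // site 2: pushed inside the loop, indexed by iteration
    int n;                   // trip count, needed to replay the loop in reverse
};

// Primal pass: computes r = (x * x) * x^n = x^(n + 2) and fills the tape.
float primal(float x, int n, StaticTape& tape)
{
    assert(n >= 0 && n <= kMaxIters);
    float k = x * x;
    tape.k = k;
    float y = 1.f;
    for (int i = 0; i < n; ++i) {
        tape.y_prev[i] = y; // array slot instead of a dynamic push
        y = y * x;
    }
    tape.y_final = y;
    tape.n = n;
    return k * y;
}

// Reverse pass: reads only x and the tape; returns dr/dx.
float backward(float x, const StaticTape& tape)
{
    float dk = tape.y_final; // d(k * y)/dk
    float dy = tape.k;       // d(k * y)/dy
    float dx = 0.f;
    for (int i = tape.n - 1; i >= 0; --i) {
        dx += dy * tape.y_prev[i]; // from y = y_prev * x
        dy = dy * x;
    }
    dx += dk * 2.f * x; // from k = x * x
    return dx;
}
```

With `x = 2` and `n = 2`, `primal` returns `2⁴ = 16` and `backward` returns `4 · 2³ = 32`.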

@saipraveenb25 saipraveenb25 added kind:cleanup tech debt and rough edges priority:medium nice to have in next milestone goal:forward looking labels May 16, 2024
@saipraveenb25 saipraveenb25 added this to the Q3 2024 (Summer) milestone May 16, 2024
@saipraveenb25 saipraveenb25 self-assigned this May 16, 2024