
Option to place arrows below w.r.t atop in diagram #124

Open
jerinphilip opened this issue Sep 1, 2023 · 5 comments

Comments

@jerinphilip

jerinphilip commented Sep 1, 2023

The connect_ family of functions draws diagram + arrow as opposed to arrow + diagram; each of them ends with:

return dia + arrow_between(ps, pe, style)

Since atop (+) is not commutative, the two orders render differently, and I've found a few use cases for the latter. I think both are useful, so this is a feature request for a switch that selects arrow + diagram or diagram + arrow.
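The difference between the two orders can be illustrated with a small, library-independent Python sketch. It does not use chalk; it only assumes what this thread describes, namely that `a + b` paints `b` over `a` (so `dia + arrow` puts the arrow on top). `Prim`, `Compose`, `paint_order`, `connect`, and its `arrow_below` keyword are all hypothetical names: the toy `Compose` merely mimics a two-child composition node, and `arrow_below` is the kind of switch being requested here, not an existing chalk parameter.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class Prim:
    """A leaf shape identified by name (toy stand-in for a primitive)."""
    name: str

    def __add__(self, other: "Diagram") -> "Compose":
        return Compose(self, other)


@dataclass(frozen=True)
class Compose:
    """Toy composition node: diagram2 is painted after (over) diagram1."""
    diagram1: "Diagram"
    diagram2: "Diagram"

    def __add__(self, other: "Diagram") -> "Compose":
        return Compose(self, other)


Diagram = Union[Prim, Compose]


def paint_order(d: Diagram) -> List[str]:
    """Flatten a diagram into painting order (last element = topmost)."""
    if isinstance(d, Prim):
        return [d.name]
    return paint_order(d.diagram1) + paint_order(d.diagram2)


def connect(dia: Diagram, start: str, end: str, arrow_below: bool = False) -> Compose:
    """Hypothetical connect with a switch controlling the arrow's depth."""
    arrow = Prim(f"arrow {start}->{end}")
    # Default keeps the current behaviour (diagram + arrow, arrow on top);
    # arrow_below=True builds arrow + diagram instead.
    return arrow + dia if arrow_below else dia + arrow


stack = Prim("split") + Prim("linear")
print(paint_order(connect(stack, "split", "linear")))
# ['split', 'linear', 'arrow split->linear']
print(paint_order(connect(stack, "split", "linear", arrow_below=True)))
# ['arrow split->linear', 'split', 'linear']
```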

To motivate the use case with an example: below, on the left, is a reproduction of MultiHeadAttention, next to the figure from the original paper on the right.

One way to specify this diagram declaratively is to stack a ScaledDotProductAttention (SDPA) block for each head and then have Split branch out to each head. connect_outside works for me here, but since I do the connect after the stack is created, the arrows appear on top and do not honor the depth dimension.

I have had some success with a custom Trail-based arrow; see arrow_outside_up_free and arrow_outside_up.

PS: I don't mean to pile on issues here; I'm happy to help and open a PR myself once we reach consensus, with some guidance. Also, if there are recommended alternatives using existing primitives that solve the above problem, I'm open to trying those as well. Thanks for building and maintaining the library!

@danoneata
Collaborator

Hello Jerin! I found some time to read your issue. For this particular example, I think the easiest way to achieve the correct layering is to draw each head's arrows as soon as you have drawn the other components of that head. That is, move this part

            dia = arrow_outside_up(
                dia,
                f"linear_{x}_" + str(i),
                "sdpa_" + str(i),
                left=True,
            )

in the head function:

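        # Stack the SDPA block above this head's Q/K/V linear blocks.
        dia = vcat(center([sdpa, hcat([q, k, v], hspace / 2)]), vspace)
        # Connect here, inside head(), so each head's arrows are layered
        # together with that head instead of on top of the whole stack.
        for x in "vkq":
            dia = arrow_outside_up(
                dia,
                f"linear_{x}_{i}",
                f"sdpa_{i}",
                left=True,
            )
        dia = dia.center_xy().translate(dx, dy).fill_opacity(opacity)
        return dia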
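Why this reordering fixes the layering can be checked with a toy model, sketched below under the same assumption as the rest of the thread (`a + b` paints `b` over `a`). Here a diagram is just the list of its shape names in paint order, so composition becomes list concatenation; `atop`, `arrow`, `head`, and `depth` are hypothetical stand-ins, not the real code. Connecting inside each head keeps that head's arrows directly above its own blocks, while connecting after stacking paints every arrow above every head.

```python
from functools import reduce
from typing import List

# Toy model: a diagram is the list of its shape names in paint order
# (last = topmost), so `a + b` (b painted over a) is list concatenation.
Layers = List[str]


def atop(a: Layers, b: Layers) -> Layers:
    return a + b


def arrow(x: str, i: int) -> Layers:
    return [f"arrow linear_{x}_{i}->sdpa_{i}"]


def head(i: int, with_arrows: bool) -> Layers:
    """Hypothetical stand-in for one attention head's blocks."""
    dia = [f"sdpa_{i}", f"linear_q_{i}", f"linear_k_{i}", f"linear_v_{i}"]
    if with_arrows:
        # Connect inside head(): arrows sit right above this head's blocks.
        for x in "vkq":
            dia = atop(dia, arrow(x, i))
    return dia


def depth(layers: Layers, name: str) -> int:
    """Paint position of a shape; a higher value is closer to the viewer."""
    return layers.index(name)


n_heads = 3

# Suggested fix: connect inside each head, then stack the heads.
per_head = reduce(atop, [head(i, with_arrows=True) for i in range(n_heads)])

# Original approach: stack the heads first, connect afterwards.
after = reduce(atop, [head(i, with_arrows=False) for i in range(n_heads)])
for i in range(n_heads):
    for x in "vkq":
        after = atop(after, arrow(x, i))

print(depth(per_head, "arrow linear_v_0->sdpa_0") < depth(per_head, "sdpa_1"))  # True
print(depth(after, "arrow linear_v_0->sdpa_0") > depth(after, "linear_v_2"))  # True
```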

This change does seem to yield the desired output:
[Screenshot: rendered output with the arrows moved into head]

Otherwise, even with the left flag, it seems to me that we cannot reach a satisfactory result: the arrow from linear_v_2 to sdpa_2 is behind the linear_v_0 and linear_v_1 blocks (although this is not immediately obvious because the linear blocks are transparent). So maybe a more general solution is needed? I recall that @srush also suggested a similar idea as a possible improvement, but I don't know whether he had an approach in mind:

Render arrows in between two already plotted values in Z-space (Not sure if this is possible in a functional system)

@srush
Collaborator

srush commented Sep 11, 2023

Do we know what Haskell diagrams does here? I can try to figure it out.

@danoneata
Collaborator

Do we know what Haskell diagrams does here?

I don't recall seeing anything similar in the Haskell codebase 🤔. Maybe the idea of delayed composition is related, as it allows reordering the components of a diagram after their creation, but I don't feel that it is much easier to use than simply creating the diagram in the "right" order from the get-go. Otherwise, for arrow connections, the connect-like functions in Haskell diagrams also seem to use a predefined order (via atop).

@jerinphilip
Author

jerinphilip commented Sep 12, 2023

It looks like my wall of text created some confusion and a few errors 😓. Let me clarify.

  1. A flag that lets the user change the order of atop would be a backwards-compatible change: the default stays diagram + arrow, and an option switches to arrow + diagram. As for arrow_outside_up and arrow_outside_free, I was pointing at how their left flag swaps the order of atop. I argue that such an option on the connect family of functions would be a net improvement without breaking any existing code. (If there is a goal of staying closely compatible with Haskell diagrams, this might be a problem.)

  2. The rest was me arguing for the use case of such a flag. To clarify, I'm trying to use the switch to create the Split -> Linear (branch) connect arrows. Here I want to build the z-stack of Linears first, vcat it with Split, and only then connect. connect_outside or connect would work for me, except that currently (without the switch) the arrows end up on top, because the connects happen after the Linear z-stack and Split are declared (I see this wasn't clear in my writing, apologies 😓).


I'm not sure I follow your response:

That is, move this part ... in the head function

Both permalinks are same. Assuming you meant move this function to head. Comparing my former code with your suggestion, I see an opacity difference showing that the arrows now get the proper z-index, and your version is clearly the better way to render. Thanks for pointing that out. Let me know if I'm missing something.

@danoneata
Collaborator

danoneata commented Sep 16, 2023

Thanks for the clarification, Jerin! I understand your motivation for the flag, but while it solves the "split" to "linear" use case, it doesn't seem to solve the "linear" to "scaled dot product attention" example.

Namely, we cannot achieve something like this:
[Screenshot: desired layering]
with either the original connect, which would yield
[Screenshot: result with the original connect]
or the reversed-order version, which would yield
[Screenshot: result with the reversed-order connect]

Ideally, I would prefer a solution that also addresses this sort of case. Sasha, do you have any opinion on this matter?
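One possible direction for such a general solution, in the spirit of the Z-space idea quoted earlier, is to give every primitive an explicit depth key and let the renderer sort by it, so an arrow can be slotted between existing layers no matter when it is composed. The sketch below is purely hypothetical and independent of chalk: `Item`, its `z` field, `connect_at_depth`, and `render_order` are invented names, and it assumes integer depths so that `+ 0.5` falls strictly between two layers.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Item:
    name: str
    z: float  # hypothetical explicit depth; larger = closer to the viewer


def render_order(items: List[Item]) -> List[str]:
    """Painter's order: sort by z; ties keep composition order (sorted is stable)."""
    return [it.name for it in sorted(items, key=lambda it: it.z)]


def connect_at_depth(items: List[Item], start: str, end: str, above: str) -> List[Item]:
    """Hypothetical connect that renders the arrow just above the item named
    `above`, regardless of when the arrow is added (assumes integer depths)."""
    z_of: Dict[str, float] = {it.name: it.z for it in items}
    arrow = Item(f"arrow {start}->{end}", z_of[above] + 0.5)
    return items + [arrow]


# Blocks of three stacked heads; head i sits at depth i (head 2 in front).
dia = [Item(f"{kind}_{i}", float(i)) for i in range(3) for kind in ("sdpa", "linear_v")]
# The arrow is composed last but rendered above head 0 and below heads 1 and 2.
dia = connect_at_depth(dia, "linear_v_0", "sdpa_0", above="sdpa_0")
print(render_order(dia))
# ['sdpa_0', 'linear_v_0', 'arrow linear_v_0->sdpa_0', 'sdpa_1', 'linear_v_1', 'sdpa_2', 'linear_v_2']
```

Because the depth travels with each primitive and sorting happens only at render time, diagrams stay immutable values, which may ease the concern about doing this in a functional system.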

Otherwise, here is a hack that reverses the order of the two top-level elements after a connect:

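def connect_outside_reverse(*args, **kwargs):
    # connect_outside returns dia + arrow; swapping the two halves of that
    # composition draws the arrow underneath the diagram instead.
    dia = connect_outside(*args, **kwargs)
    return dia.diagram2 + dia.diagram1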

or, written as a combinator:

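from chalk.core import Compose

def zswap(diagram):
    # Swap the two sides of the outermost composition (a + b -> b + a);
    # any other diagram is returned unchanged.
    if isinstance(diagram, Compose):
        return diagram.diagram2 + diagram.diagram1
    else:
        return diagram

zswap(connect_outside(dia, f"bot {i}", f"top {i}"))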
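The effect of the hack, and one pitfall, can be seen in a toy model that mimics only the `diagram1`/`diagram2` structure used above. `Prim`, `Compose`, and `paint_order` here are hypothetical classes and functions, not chalk's, and, as elsewhere in the thread, the assumption is that `diagram2` is painted over `diagram1`. Because `zswap` flips only the outermost composition, it has to be applied directly to the result of the connect: once something else has been composed on top, it reorders that instead.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class Prim:
    name: str

    def __add__(self, other):
        return Compose(self, other)


@dataclass(frozen=True)
class Compose:
    # Mirrors the diagram1/diagram2 fields used by the hack;
    # diagram2 is painted over diagram1.
    diagram1: "Union[Prim, Compose]"
    diagram2: "Union[Prim, Compose]"

    def __add__(self, other):
        return Compose(self, other)


def paint_order(d) -> List[str]:
    """Flatten into painting order (last element = topmost)."""
    if isinstance(d, Prim):
        return [d.name]
    return paint_order(d.diagram1) + paint_order(d.diagram2)


def zswap(diagram):
    # Same logic as the combinator: swap only the outermost pair.
    if isinstance(diagram, Compose):
        return diagram.diagram2 + diagram.diagram1
    return diagram


dia = Prim("bot") + Prim("top")
connected = dia + Prim("arrow")  # the dia + arrow shape a connect produces
print(paint_order(zswap(connected)))  # ['arrow', 'bot', 'top']
# Pitfall: after further composition, zswap no longer targets the arrow.
print(paint_order(zswap(connected + Prim("label"))))  # ['label', 'bot', 'top', 'arrow']
```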

Both permalinks are same. Assuming you meant move this function to head.

Oops, yes, you are right! Here is the corresponding code that generated that figure.
