
[Feature suggestion] Allow duplicate and non-unitary axes in indexing expressions #204

Open
simonalford42 opened this issue Aug 31, 2022 · 0 comments

I repeatedly find myself using duplicate axes and non-unitary anonymous axes (integer literals other than 1) in my einops patterns, only to get an error explaining that they are not allowed. For example,

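```python
import torch
from einops import rearrange
c = 10
a = torch.arange(c * c * 4)
rearrange(a, '(c c 4) -> c c 4', c=c)  # proposed syntax; currently raises an error
```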

It seems reasonable for this to "just work" by pairing duplicate axes in left-to-right order and parsing numbers as axis lengths. Instead, I have to write `rearrange(a, '(c1 c2 f) -> c1 c2 f', c1=c, c2=c, f=4)`.
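The proposed rule (left-to-right ordering for duplicates, numbers parsed as lengths) amounts to a purely textual rewrite into the explicit workaround. Below is a minimal stdlib-only sketch of such a rewrite. `desugar` and the generated names (`c_0`, `lit4_0`, ...) are hypothetical and not part of einops; the sketch assumes the k-th occurrence of a name on the left pairs with its k-th occurrence on the right, and it rejects a name that appears a different number of times on the two sides (such as `'c c -> c'`), since that case is ambiguous.

```python
import re
from collections import Counter


def _tokens(side):
    # "(", ")", "..." or a bare word (axis name or integer literal)
    return re.findall(r"\.\.\.|[()]|[^\s()]+", side)


def _is_literal(tok):
    # Integer axes other than 1 (einops already accepts a unitary "1")
    return tok.isdigit() and tok != "1"


def desugar(pattern, **sizes):
    """Hypothetical helper: rewrite duplicate names and integer literals in an
    einops-style pattern into unique axis names plus their lengths."""
    sides = [_tokens(s) for s in pattern.split("->")]
    if len(sides) != 2:
        raise ValueError("expected exactly one '->'")
    counts = [Counter(t for t in toks if t not in ("(", ")", "...")) for toks in sides]

    # Rename integer literals and any name repeated on either side.
    to_rename = {
        t for cnt in counts for t, n in cnt.items()
        if _is_literal(t) or (n > 1 and not t.isdigit())
    }
    for t in to_rename:
        if t in counts[0] and t in counts[1] and counts[0][t] != counts[1][t]:
            raise ValueError(f"ambiguous duplicate axis {t!r}")

    new_sizes = {k: v for k, v in sizes.items() if k not in to_rename}
    new_sides = []
    for toks in sides:
        seen = Counter()
        out = []
        for t in toks:
            if t in to_rename:
                # k-th occurrence on this side gets suffix k (left-to-right pairing)
                name = f"lit{t}_{seen[t]}" if _is_literal(t) else f"{t}_{seen[t]}"
                seen[t] += 1
                if _is_literal(t):
                    new_sizes[name] = int(t)
                elif t in sizes:
                    new_sizes[name] = sizes[t]
                t = name
            out.append(t)
        new_sides.append(" ".join(out).replace("( ", "(").replace(" )", ")"))
    return " -> ".join(new_sides), new_sizes


print(desugar("(c c 4) -> c c 4", c=10))
# ('(c_0 c_1 lit4_0) -> c_0 c_1 lit4_0', {'c_0': 10, 'c_1': 10, 'lit4_0': 4})
```

With einops installed, the result could be passed straight through, e.g. `pattern, sizes = desugar('(c c 4) -> c c 4', c=c)` followed by `rearrange(a, pattern, **sizes)`. A real implementation would also need to avoid clashes between the generated names and names the user already uses.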

I find it natural to give two different axes the same name when they have the same size and the same semantic meaning. It is also natural to name an axis by its integer length instead of a variable. This comes up in math all the time: an $n \times n$ square matrix, or a $2 \times n \times n$ tensor. Disallowing this forces me to invent unintuitive names such as `c1 c2 f` in place of `c c 4`, created solely to work around the restriction and otherwise carrying no semantic meaning (here the `f` is just the first letter of "four").

If I had to guess, the reasons for not enabling these features are:

  • Duplicate axes introduce ambiguity in how to pair them. Ordering them left to right by default seems like an intuitive approach.
  • Non-unitary axes would require some extra parsing, but that seems manageable.

Both of these seem easy enough to implement. I am not an einops super-user, but I looked at the other operations to judge whether these extensions would also work with `reduce` and `repeat`.

Rearrange

  • Whatever the expression, left-to-right seems like a natural interpretation of duplicate axes.
  • It seems like non-unitary axes would work fine.
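To make the ordering question concrete, here is a pure-Python sketch (no einops; both function names are hypothetical) of the two ways the duplicated `c` axes in `'(c c 4) -> c c 4'` could be paired. The left-to-right pairing is an ordinary row-major reshape, while the other pairing silently swaps the two `c` axes.

```python
def unflatten_left_to_right(a, c, f):
    """'(c c f) -> c c f' with the k-th `c` on the left paired with the k-th `c`
    on the right: an ordinary row-major reshape of a flat list."""
    return [[[a[(i * c + j) * f + k] for k in range(f)]
             for j in range(c)]
            for i in range(c)]


def unflatten_swapped(a, c, f):
    """The alternative pairing: the first output `c` takes the second input `c`,
    which amounts to a reshape followed by swapping the two `c` axes."""
    return [[[a[(j * c + i) * f + k] for k in range(f)]
             for j in range(c)]
            for i in range(c)]


a = list(range(2 * 2 * 2))
print(unflatten_left_to_right(a, c=2, f=2))  # [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(unflatten_swapped(a, c=2, f=2))        # [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
```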

Reduce

  • It is probably best to keep disallowing duplicates here. Alternatively, the first of the duplicates could be the one that is reduced, so that `reduce(a, 'c c -> c', 'max')` would be the same as `reduce(a, 'c1 c2 -> c2', 'max')`, but this seems ambiguous to me.
  • It seems like non-unitary axes would work fine.
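A quick pure-Python check (standing in for a max-reduction; no einops involved) shows why `reduce(a, 'c c -> c', 'max')` is ambiguous: for a non-symmetric matrix the two possible readings give different results, so picking one would silently change the output.

```python
# A 2 x 2 matrix whose two axes would both be named `c`.
m = [[1, 5],
     [4, 2]]

# Reading 1 ("first duplicate is reduced"): 'c1 c2 -> c2', max down each column.
max_first_axis = [max(row[j] for row in m) for j in range(len(m[0]))]

# Reading 2: 'c1 c2 -> c1', max within each row.
max_second_axis = [max(row) for row in m]

print(max_first_axis)   # [4, 5]
print(max_second_axis)  # [5, 4]
```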

Repeat

  • I don't see an issue with duplicates here: `repeat(a, 't -> t b b', b=3)` would be the same as `repeat(a, 't -> t b1 b2', b1=3, b2=3)`.
  • Here non-unitary axes could make the biggest difference. Instead of `repeat(a, 'h w -> h w c', c=3)`, you could write `repeat(a, 'h w -> h w 3')`, which I find easier to read.
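Under the proposed reading, `'t -> t b b'` with `b=3` adds two new axes of length 3, and `'h w -> h w 3'` adds one new axis of length 3, exactly like their named-axis equivalents. The pure-Python stand-in below (a hypothetical `repeat_new_axes`, not an einops function, and limited to appending new trailing axes) shows the resulting nested structure for both cases.

```python
def _fill(value, shape):
    # Fresh nested list of the given shape, every entry equal to `value`.
    if not shape:
        return value
    return [_fill(value, shape[1:]) for _ in range(shape[0])]


def repeat_new_axes(x, *new_sizes):
    """Hypothetical stand-in for a repeat that only appends new trailing axes,
    e.g. 't -> t b1 b2' or 'h w -> h w c'."""
    if isinstance(x, list):
        return [repeat_new_axes(item, *new_sizes) for item in x]
    return _fill(x, new_sizes)


# 't -> t b b' with b=3, read as 't -> t b1 b2' with b1=b2=3
print(repeat_new_axes([1, 2], 3, 3)[0])  # [[1, 1, 1], [1, 1, 1], [1, 1, 1]]

# 'h w -> h w 3', read as 'h w -> h w c' with c=3
print(repeat_new_axes([[1, 2], [3, 4]], 3))  # [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
```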