Formal descriptions of neural networks primarily adopt the notation of vectors and matrices from applied linear algebra~\citep{goodfellow2016deep}. When used to describe vector spaces, this notation is both concise and unambiguous. However, when applied to neural networks, these properties are lost. Consider the equation for attention as notated in the Transformer paper \citep{vaswani+:2017}:
\[ \text{Attention}(Q, K, V) = \left( \softmax \frac{QK^\top}{\sqrt{d_k}} \right) V. \]
The equation relates $Q$, $K$, and $V$ (for query, key, and value, respectively) as sequences of feature vectors, packed into possibly identically-sized matrices. While concise, this equation is ambiguous. Does the product $QK^\top$ sum over the sequence, or over the features? We know that it sums over columns, but there is not enough information to know what the columns represent. Is the softmax taken over the query sequence or the key sequence? The usual notation does not offer an answer. Perniciously, the implementation of an incorrect interpretation might still run without errors. With the addition of more axes, like multiple attention heads or multiple sentences in a minibatch, the notation becomes even more cumbersome.
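As a concrete illustration of this failure mode (ours, not from the paper), the following numpy sketch shows that when the sequence length and feature size happen to coincide, a softmax taken over the wrong axis of $QK^\top$ runs without any error:

```python
import numpy as np

# Sketch: with deliberately equal sequence length and feature size,
# the score matrix is square, so normalizing over the wrong axis
# produces no shape error -- the bug is silent.
n, d_k = 4, 4
rng = np.random.default_rng(0)
Q = rng.normal(size=(n, d_k))   # rows: queries, columns: features
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_k))

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

scores = Q @ K.T / np.sqrt(d_k)
right = softmax(scores, axis=-1) @ V   # normalize over keys (intended)
wrong = softmax(scores, axis=0) @ V    # normalize over queries (a bug)

# Both results have shape (n, d_k); nothing signals the mistake.
assert right.shape == wrong.shape == (n, d_k)
assert not np.allclose(right, wrong)
```

Positional shapes cannot distinguish the two interpretations; only a convention carried in the reader's head does.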
We propose an alternative mathematical notation for tensors with \emph{named axes}.\footnote{%
We follow NumPy in using the term \emph{axis}. Other possible terms would be \emph{index}, \emph{dimension}, \emph{way}, or \emph{mode} \citep{tucker:1964}, but we felt that \emph{axis} had the least potential for confusion.} The notation has a formal underpinning, but is hopefully intuitive enough that machine learning researchers can understand it without much effort.
%
In named tensor notation, the above equation becomes
\begin{align*}
\text{Attention} \colon \mathbb{R}^{\key} \times \mathbb{R}^{\seq \times \key} \times \mathbb{R}^{\seq \times \val} &\rightarrow \mathbb{R}^{\val} \\
\text{Attention}(Q,K,V) = \left( \nfun{\seq}{softmax} \frac{Q \ndot{\key} K}{\sqrt{|\key|}} \right) \ndot{\seq} V.
\end{align*}
The type signature introduces three named axes: the $\key$ axis is for features of queries and keys, the $\val$ axis is for features of values, and the $\seq$ axis is for tokens in a sequence. (Please see \cref{sec:goodnames} for an explanation of our naming convention.) This notation makes the types of each input tensor explicit. Tensor $Q$ is a query vector that is compared with key vectors, so it has a $\key$ axis. Tensor $K$ is a sequence of key vectors, so it has $\seq$ and $\key$ axes. Tensor $V$ is a sequence of value vectors, so it has $\seq$ and $\val$ axes. Unlike with matrix notation, the reader is not required to remember whether $\seq$ corresponds to rows or columns in either of these tensors.
The function itself uses the named axes to precisely apply operations. The expression $Q \ndot{\key} K$ is a dot product over the $\key$ axis shared between $K$ and $Q$; there is no ambiguity about rows or columns. Similarly, the softmax function is annotated with the axis along which it is applied, removing any ambiguity or reliance on convention.
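To make the correspondence concrete, here is a rough executable analogue (our illustration, not part of the paper's proposal) in which einsum subscript letters play the role of axis names: `k` for the $\key$ axis, `s` for $\seq$, and `v` for $\val$:

```python
import numpy as np

def attention(Q, K, V):
    """Q: (key,), K: (seq, key), V: (seq, val) -> (val,)."""
    # Dot product over the key axis, as in Q (dot_key) K.
    scores = np.einsum('k,sk->s', Q, K) / np.sqrt(Q.shape[-1])
    # Softmax explicitly over the seq axis.
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Contract the seq axis against V, as in (dot_seq) V.
    return np.einsum('s,sv->v', weights, V)

out = attention(np.ones(8), np.ones((5, 8)), np.ones((5, 3)))
assert out.shape == (3,)           # a single value-feature vector
```

With all-ones inputs the softmax weights are uniform, so the output is the all-ones value vector; the point is that each contraction and the softmax name their axis explicitly.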
Furthermore, named tensor notation naturally extends to \textit{lifting} (also known as vectorizing or broadcasting) a function to tensors with more axes. For example, if instead of being a tensor with the single axis $\key$, $Q$ has three axes $\key$, $\seq$, and $\batch$ (the latter two corresponding to tokens of a sequence and examples in a minibatch, respectively), then the $\text{Attention}$ function works as written, acting on each token and each example in a minibatch in parallel.
%That is, in this case $\text{Attention}(Q,K,V)$ equals the tensor $A$ with axes $\val$, $\seq$ and $\batch$, such that for every index $s$ of $\seq$ and $b$ of $\batch$, the corresponding element of $A$ is obtained by applying $\text{Attention}$ to the corresponding restriction of $Q$ (together with $K$ and $V$).
Similarly, we can add a $\heads$ axis to the inputs to obtain multiple attention heads.
These additional axes are often elided in neural network papers, which avoids notational complexity but can also hide critical model details.
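In code, this lifting can be approximated (again, our sketch rather than the paper's mechanism) with an einsum ellipsis that absorbs whatever extra axes $Q$ carries, leaving $K$ and $V$ as declared; named axes achieve the same effect by name, with no positional convention to reserve:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: (..., key), K: (seq, key), V: (seq, val) -> (..., val).
    # '...' stands for any extra axes of Q (batch, heads, token position),
    # so the function is applied to each slice of Q in parallel.
    scores = np.einsum('...k,sk->...s', Q, K) / np.sqrt(Q.shape[-1])
    return np.einsum('...s,sv->...v', softmax(scores, axis=-1), V)

K = np.ones((6, 8))
V = np.ones((6, 3))
single = attention(np.ones(8), K, V)          # Q: (key,)
lifted = attention(np.ones((2, 5, 8)), K, V)  # Q: (batch, seq, key)
assert single.shape == (3,)
assert lifted.shape == (2, 5, 3)
```

The function body is unchanged between the two calls; only the shape of $Q$ grows, mirroring how the named-tensor $\text{Attention}$ works as written when $Q$ gains $\seq$ and $\batch$ axes.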
\textbf{Our contributions.} This work proposes a \emph{mathematical notation} for named tensors and a fully specified \emph{semantic interpretation} for the notation.
Through examples, we demonstrate that this notation enables specifying machine learning models and operations in a succinct yet precise manner.
The need for named tensors has been recognized by several software packages, including xarray \citep{xarray}, Nexus \citep{chen2017typesafe}, tsalib \citep{tsalib}, axisarrays~\citep{axisarrays}, NamedTensor \citep{namedtensor}, PyTorch \citep{named-tensors}, Dex~\citep{dex}, JAX \citep{jax_xmap}, einops \citep{einops}, and torchdim \citep{torchdim}. While our notation is inspired by these efforts, our focus is on mathematical notation to be used in papers, whereas previous efforts have focused on code. Our hope is that our notation will be adopted by authors, leading to clearer, more replicable papers, and that this, in turn, will encourage more implementers to adopt named tensor libraries, leading to clearer, more correct code.
%The source code for this document can be found at \url{https://github.com/namedtensor/notation/}. We invite anyone to make comments on this proposal by submitting issues or pull requests on this repository.