

# Decoder-Only Transformers

Notes on various aspects of Decoder-Only Transformers. Conventions are in App. A.

## Contents

|           |                                                              |           |
|-----------|--------------------------------------------------------------|-----------|
| <b>I</b>  | <b>Architecture</b>                                          | <b>2</b>  |
| <b>1</b>  | <b>Decoder-Only Fundamentals</b>                             | <b>3</b>  |
| 1.1       | Embedding Layer and Positional Encodings                     | 4         |
| 1.2       | Layer Norm                                                   | 4         |
| 1.3       | Causal Attention                                             | 5         |
| 1.4       | MLP                                                          | 8         |
| 1.5       | Language Model Head                                          | 8         |
| 1.6       | All Together                                                 | 9         |
| 1.7       | The Loss Function                                            | 10        |
| <b>2</b>  | <b>Architecture Variants</b>                                 | <b>11</b> |
| 2.1       | Multi-Query Attention                                        | 11        |
| 2.2       | Grouped Attention                                            | 12        |
| 2.3       | Parallel MLP and CausalAttention Layers                      | 12        |
| 2.4       | RoPE Embeddings                                              | 12        |
| 2.5       | Flash Attention                                              | 13        |
| <b>II</b> | <b>Training</b>                                              | <b>15</b> |
| <b>3</b>  | <b>Memory</b>                                                | <b>15</b> |
| 3.1       | No Sharding                                                  | 15        |
| 3.1.1     | Parameters, Gradients, Optimizer States, and Mixed Precision | 15        |
| 3.1.2     | Gradients                                                    | 17        |
| 3.1.3     | Activations                                                  | 17        |
| 3.2       | Tensor Parallelism                                           | 19        |
| 3.3       | Sequence Parallelism                                         | 23        |
| 3.4       | Pipeline Parallelism                                         | 25        |
| 3.5       | Case Study: Mixed-Precision GPT3                             | 25        |
| <b>4</b>  | <b>Training FLOPs</b>                                        | <b>26</b> |
| 4.1       | No Recomputation                                             | 27        |

|                                                           |           |
|-----------------------------------------------------------|-----------|
| <b>5 Training Time</b>                                    | <b>28</b> |
| <b>6 Scaling Laws</b>                                     | <b>29</b> |
| 6.1 Original Scaling Laws                                 | 30        |
| 6.2 Chinchilla Scaling Laws                               | 30        |
| <br>                                                      |           |
| <b>III Inference</b>                                      | <b>31</b> |
| <b>7 Basics and Problems</b>                              | <b>32</b> |
| <b>8 Generation Strategies</b>                            | <b>32</b> |
| 8.1 Greedy                                                | 32        |
| 8.2 Simple Sampling: Temperature, Top- $k$ , and Top- $p$ | 33        |
| 8.3 Beam Search                                           | 33        |
| <b>9 The Bare Minimum and the kv-Cache</b>                | <b>33</b> |
| <b>10 Basic Memory, FLOPs, Communication, and Latency</b> | <b>35</b> |
| <b>11 Case Study: Falcon-40B</b>                          | <b>36</b> |
| <b>A Conventions and Notation</b>                         | <b>36</b> |
| <b>B Collective Communications</b>                        | <b>38</b> |
| <b>C Hardware</b>                                         | <b>39</b> |
| C.1 NVIDIA GPU Architecture                               | 39        |
| C.2 CUDA Programming Model                                | 40        |
| C.3 NVIDIA GPU Stats                                      | 40        |
| <b>D Compute-bound vs Memory-bound</b>                    | <b>41</b> |
| D.1 Matrix-Multiplications vs. Element-wise Operations    | 41        |
| D.2 Training vs. Inference                                | 41        |
| D.3 Intra- and Inter-Node Communication                   | 42        |
| <b>E Batch Size, Compute, and Training Time</b>           | <b>42</b> |
| <b>F Cheat Sheet</b>                                      | <b>44</b> |
| <b>G TODO</b>                                             | <b>45</b> |

# Part I

## Architecture

### 1 Decoder-Only Fundamentals

The Transformers architecture [1], which dominates Natural Language Processing (NLP) as of July 2023, is a relatively simple architecture. There are various flavors and variants of Tranformers, but focus here on the decoder-only versions which underlie the GPT models [2–4].

The full decoder-only architecture can be seen in Fig. 1. The parameters which define the network can be found in App. A.



**Figure 1.** The full transformers architecture. Diagram taken from [5]

At a high level, decoder-only transformers take in an ordered series of word-like objects, called tokens, and are trained to predict the next token in the sequence. Given some initial text, transformers can be used to give a prediction for the likelihood of any possible continuation of that text. An outline of the mechanics<sup>1</sup>:

1. Raw text is **tokenized** and turned into a series of integers<sup>2</sup> whose values lie in  $\text{range}(V)$ , with  $V$  the vocabulary size.
2. The tokenized text is chunked and turned into  $(B, S)$ -shaped (batch size and sequence length, respectively) integer tensors,  $x_{bs}$ .
3. The **embedding layer** converts the integer tensors into continuous representations of shape  $(B, S, D)$ ,  $z_{bsd}$ , with  $D$  the size of the hidden dimension. **Positional encodings** have also been added to the tensor at this stage to help the architecture understand the relative ordering of the text.
4. The  $z_{bsd}$  tensors pass through a series of transformer blocks, each of which has two primary components:

<sup>1</sup>This describes the vanilla architecture; almost every component is modified in the available variants.

<sup>2</sup>There are about 1.3 tokens per word, on average.

- (a) In the **attention** sub-block, components of  $z_{bsd}$  at different positions ( $s$ -values) interact with each other, resulting in another (B, S, D)-shaped tensor,  $z'_{bsd}$ .
- (b) In the **MLP** block, each position in  $z'_{bsd}$  is processed independently and in parallel by a two-layer feed-forward network, resulting once more in a (B, S, D)-shaped tensor.

Importantly, there are **residual connections** around each of these<sup>3</sup> (the arrows in Fig. 1), meaning that the output of each block is added back to its original input.

5. Finally, we convert the (B, S, D)-shaped tensors to (B, S, V)-shaped ones,  $y_{bsv}$ . This is the role of the **language model head** (which is often just the embedding layer used in an inverse manner.)
6. The  $y_{bsv}$  predict what the next token will be, i.e.  $x_{bs+1}$ , having seen the **context** of the first  $s$  tokens in the sequence. Specifically, removing the batch index for simplicity, a **Softmax** of  $y_{sv}$  gives the conditional probability  $p_{sv} = P(t_{s+1}|t_s \dots t_0)$  for the indicated series of tokens. Because of the chain rule of probability, these individual probabilities can be combined to form the probability that any sequence of tokens follows a given initial seed<sup>4</sup>.

Each batch (the  $b$ -index) is processed independently. We omitted **LayerNorm** and **Dropout** layers above, as well as the causal mask; these will be covered below as we step through the architecture in more detail.

## 1.1 Embedding Layer and Positional Encodings

The **embedding** layer is just a simple look up table: each of the **range(V)** indices in the vocabulary is mapped to a  $D$ -dimensional vector via a large (V, D)-shaped table/matrix. This layer maps  $x_{bs} \rightarrow z_{bsd}$ . In **torch**, this is an `nn.Embedding(V, D)` instance.

To each item in a batch, we add identical **positional encodings** to the vectors above with the goal of adding fixed, position-dependent correlations in the sequence dimension which will hopefully make it easier for the architecture to pick up on the relative positions of the inputs<sup>5</sup>. This layer maps  $z_{bsd} \leftarrow z_{bsd} + p_{sd}$ , with  $p_{sd}$  the positional encoding tensor.

The above components require  $(V + S)D \approx VD$  parameters per model.

## 1.2 Layer Norm

The original transformers paper [1] put **LayerNorm** instances after the **attention** and **MLP** blocks, but now it is common [6] to put them before these blocks<sup>6</sup>.

---

<sup>3</sup>This gives rise to the concept of the **residual stream** which each transformer block reads from and writes back to repeatedly.

<sup>4</sup>In more detail, these probabilities are created by products:  $P(t_{s+n} \dots t_{s+1}|t_s \dots t_0) = P(t_{s+n}|t_{s+n-1} \dots t_s \dots t_0) \times \dots \times P(t_{s+1}|t_s \dots t_0) \dots$

<sup>5</sup>Positional encodings and the causal mask are the only components in the transformers architecture which carry weights with a dimension of size  $S$ ; i.e. they are the only parts that have explicit sequence-length dependence. A related though experiment: you can convince yourself that if the inputs  $z_{bsd}$  were just random noise, the transformers architecture would not be able to predict the  $s$ -index of each such input in the absence of positional encodings.

<sup>6</sup>Which makes intuitive sense for the purposes of stabilizing the matrix multiplications in the blocks

The `LayerNorm` operations acts over the sequence dimension. Spelling it out, given the input tensor  $z_{bsd}$  whose mean and variance over the  $s$ -index are  $\mu_{bd}$  and  $\sigma_{bd}$ , respectively, the `LayerNorm` output is

$$z_{bsd} \leftarrow \left( \frac{z_{bsd} - \mu_{bd}}{\sigma_{bd}} \right) \times \gamma_d + \beta_d \equiv \text{LayerNorm}_s z_{bsd} \quad (1.1)$$

where  $\gamma_d, \beta_d$  are the trainable scale and bias parameters. In `torch`, this is a `nn.LayerNorm(D)` instance. Since there are two `LayerNorm` instances in each transformer block, these components require  $2D$  parameters per layer.

We will continue discussing `LayerNorm` instances in what follows in order to adhere to the usual construction and to discuss methods like sequence-parallelism in their original form (see Sec. 3.3), but note: the data-independent `LayerNorm` transformations are completely redundant when immediately followed by a `Linear` layer, since both act linearly on their inputs and `Linear` is already the most general data-independent linear transformation. Explicitly, the  $\gamma_d, \beta_d$  parameters can be absorbed into the linear parameters as in

$$(z_{bsd}\gamma_d + \beta_d) W_{dd'} + b_{d'} = z_{bsd}W'_{dd'} + b'_{d'} , \quad W'_{dd'} \equiv \gamma_d W_{dd'} , \quad b'_{d'} \equiv b_{d'} + \beta_d W_{dd'} . \quad (1.2)$$

That is, these transformations can be equivalently performed by the weight matrix and bias (if included) in `Linear` layer<sup>7</sup>.

### 1.3 Causal Attention

**Causal attention** is the most complex layer. It features  $A$  sets of weight matrices<sup>8</sup>  $Q_{dea}, K_{dea}, V_{dea}$  where  $a \in \{0, \dots, A-1\}$  and  $e \in \{0, \dots, D/A\}$ , where  $D$  is assumed perfectly divisible by  $A$ . From these, we form three different vectors:

$$q_{bsea} = z_{bsd}Q_{dea} , \quad k_{bsea} = z_{bsd}K_{dea} , \quad v_{bsea} = z_{bsd}V_{dea} \quad (1.3)$$

These are the **query, key, and value** tensors, respectively<sup>9</sup>.

Using the above tensors, we will then build up an **attention map**  $w_{bss'a}$  which corresponds to how much attention the token at position  $s$  pays to the token at position  $s'$ . Because we have the goal of predicting the next token in the sequence, we need these weights to be causal: the final prediction  $y_{bsv}$  should only have access to information propagated from positions  $x_{bs'v}$  with  $s' \leq s$ . This corresponds to the condition that  $w_{bss'a} = 0$  if  $s' > s$ .

These weights come from `Softmax`-ed attention scores, which are just a normalized dot-product over the hidden dimension:

$$w_{bss'da} = \text{Softmax}_{s'} \left( m_{ss'} + \frac{q_{bse}k_{bs'ea}}{\sqrt{D/A}} \right) , \quad \text{s.t.} \quad \sum_{s'} w_{bdss'a} = 1 \quad (1.4)$$

---

<sup>7</sup>Note the importance of data-independence here: the data-dependent subtraction of the mean and division by the standard deviation cannot be absorbed for all elements in a batch. Note that because the usual training algorithms are not covariant under parameter redefinitions, the above unfortunately does not imply that removing `LayerNorms` will have no effect on training dynamics.

<sup>8</sup>There are also bias terms, but we will often neglect to write them explicitly or account for their (negligible) parameter count.

<sup>9</sup>There are of course many variants of the architecture and one variant which is popular in Summer 2023 is multi-query attention [7] in which all heads share *the same* key and value vectors and only the query changes across heads, as this greatly reduces inference costs.

The tensor  $m_{ss'}$  is the causal mask which zeroes out the relevant attention map components above

$$m_{ss'} = \begin{cases} 0 & s \leq s' \\ -\infty & = s > s' \end{cases},$$

forcing  $w_{bss'da} = 0$  for  $s > s'$ . In other words, the causal mask ensures that a given tensor, say  $z_{bsd}$ , only has dependence on other tensors whose sequence index, say  $s'$ , obeys  $s' \leq s$ . This is crucial for inference-time optimizations, in particular the use of the **kv-cache** in which key-value pairs do not need to be re-computed.

The  $\sqrt{D/A}$  normalization is motivated by demanding that the variance of the **Softmax** argument be 1 at initialization, assuming that other components have been configured so that the query and key components are i.i.d. from a Gaussian normal distribution <sup>10</sup>.

The weights above are then passed through a dropout layer and used to re-weigh the **value** vectors and form the tensors

$$y_{bsea} = \text{Drop}(w_{bdss'a}) v_{bs'ea} \quad (1.5)$$

and these (B, S, D/A, A)-shaped tensors are then concatenated along the  $e$ -direction to re-form a (B, S, D)-shaped tensor  $u_{bsd}$

$$u_{bsd} = y_{bs(ea)} \quad (1.6)$$

in **einops**-like notation for concatenation. Finally, another weight matrix  $O_{d'd}$  and dropout layer transform the output once again to get the final output

$$z_{bsd} = \text{Drop}(u_{bsd'} O_{d'd}) . \quad (1.7)$$

For completeness, the entire operation in condensed notation with indices left implicit is:

$$z \leftarrow \text{Drop} \left( \text{Concat} \left( \text{Drop} \left( \text{Softmax} \left( \frac{(z \cdot Q_a) \cdot (z \cdot K_a)}{\sqrt{D/A}} \right) \right) \cdot z \cdot V_a \right) \cdot O \right) \quad (1.8)$$

where all of the dot-products are over feature dimensions (those of size  $D$  or  $D/A$ ).

Below is pedagogical<sup>11</sup> sample code for such a **CausalAttention** layer<sup>12</sup>:

```

8  class CausalAttention(nn.Module):
9      def __init__(self,
10          block_size=K,
11          dropout=0.1,
12          hidden_dim=D,
13          num_attn_heads=A,
```

---

<sup>10</sup>However, in [8] it is instead argued that no square root should be taken in order to maximize the speed of learning via SGD.

<sup>11</sup>The code is written for clarity, not speed. An example optimization missing here: there is no need to form separate  $Q_a, K_a, V_a$  **Linear** layers, one large layer which is later chunked is more efficient

<sup>12</sup>When using sequence-parallelism, it will be more natural to separate out the final **Dropout** layer and combine it with the subsequent **LayerNorm**, as they are sharded together; see Sec. 3.3. The same is true for the **MLP** layer below.

```

15  ):
16      super().__init__()
17      self.block_size = block_size
18      self.dropout = dropout
19      self.hidden_dim = hidden_dim
20      self.num_attn_heads = num_attn_heads
21
22      self.head_dim, remainder = divmod(hidden_dim, num_attn_heads)
23      assert not remainder, "num_attn_heads must divide hidden_dim evenly"
24
25      self.Q = nn.ModuleList(
26          [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
27      )
28      self.K = nn.ModuleList(
29          [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
30      )
31      self.V = nn.ModuleList(
32          [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
33      )
34      self.O = nn.Linear(hidden_dim, hidden_dim)
35
36      self.attn_dropout = nn.Dropout(dropout)
37      self.out_dropout = nn.Dropout(dropout)
38      self.register_buffer(
39          "causal_mask",
40          torch.tril(torch.ones(block_size, block_size)[None]),
41      )
42
43  def get_qkv(self, inputs):
44      queries = [q(inputs) for q in self.Q]
45      keys = [k(inputs) for k in self.K]
46      values = [v(inputs) for v in self.V]
47      return queries, keys, values
48
49  def get_attn_maps(self, queries, keys):
50      S = queries[0].shape[1]
51      norm = math.sqrt(self.head_dim)
52      non_causal_attn_scores = [(q @ k.transpose(-2, -1)) / norm for q, k in zip(queries, keys)]
53      # Note: this mask shape is a bit of a hack to make generation from the KV cache work without
54      # specifying an extra boolean. When queries and keys have different sequence lengths and the
55      # queries are of seq_len == 1, p the query attends to all of the keys; effectively there is
56      # no mask at all.
57      causal_attn_scores = [
58          a.masked_fill(self.causal_mask[:, :S, :S] == 0, float("-inf"))
59          for a in non_causal_attn_scores
60      ]
61      attn_maps = [a.softmax(dim=-1) for a in causal_attn_scores]
62      return attn_maps
63
64  def forward(self, inputs):
65      queries, keys, values = self.get_qkv(inputs)
66      attn_maps = self.get_attn_maps(queries, keys)
67      weighted_values = torch.cat(
68          [self.attn_dropout(a) @ v for a, v in zip(attn_maps, values)], dim=-1
69      )

```

```

70     z = self.0(weighted_values)
71     z = self.out_dropout(z)
72     return z

```

The parameter count is dominated by the weight matrices which carry  $4D^2$  total parameters per layer.

## 1.4 MLP

The feed-forward network is straightforward and corresponds to

$$z_{bsd} \leftarrow \text{Drop}(\phi(z_{bsd'} W_{d'e}^0) W_{ed}^1) \quad (1.9)$$

where  $W^0$  and  $W^1$  are  $(D, ED)$ - and  $(ED, D)$ -shaped matrices, respectively (see App. A for notation) and  $\phi$  is a non-linearity<sup>13</sup>. In code, where we again separate out the last Dropout layer as we did in in Sec. 1.3.

```

6  class MLP(nn.Module):
7      def __init__(self,
8          hidden_dim=D,
9          expansion_factor=E,
10         dropout=0.1,
11      ):
12          super().__init__()
13          self.hidden_dim = hidden_dim
14          self.expansion_factor = expansion_factor
15          self.dropout = dropout
16
17          linear_1 = nn.Linear(hidden_dim, expansion_factor * hidden_dim)
18          linear_2 = nn.Linear(expansion_factor * hidden_dim, hidden_dim)
19          gelu = nn.GELU()
20          self.layers = nn.Sequential(linear_1, gelu, linear_2)
21          self.dropout = nn.Dropout(dropout)
22
23
24      def forward(self, inputs):
25          z = self.layers(inputs)
26          z = self.dropout(z)
27          return z

```

This block requires  $2ED^2$  parameters per layer, only counting the contribution from weights.

## 1.5 Language Model Head

The layer which converts the  $(B, S, D)$ -shaped outputs,  $z_{bsd}$ , to  $(B, S, V)$ -shaped predictions over the vocabulary,  $y_{bsv}$ , is the **Language Model Head**. It is a linear layer, whose weights are usually tied to be exactly those of the initial embedding layer of Sec. 1.1.

---

<sup>13</sup>The GeLU non-linearity is common.

## 1.6 All Together

It is then relatively straightforward to tie everything together. In code, we can first create a transformer block like

```
8 class TransformerBlock(nn.Module):
9     def __init__(self,
10         block_size=K,
11         dropout=0.1,
12         expansion_factor=E,
13         hidden_dim=D,
14         num_attn_heads=A,
15         num_layers=L,
16         vocab_size=V,
17     ):
18         super().__init__()
19         self.block_size = block_size
20         self.dropout = dropout
21         self.expansion_factor = expansion_factor
22         self.hidden_dim = hidden_dim
23         self.num_attn_heads = num_attn_heads
24         self.num_layers = num_layers
25         self.vocab_size = vocab_size
26
27         self.attn_ln = nn.LayerNorm(hidden_dim)
28         self.attn = CausalAttention(
29             block_size=block_size,
30             dropout=dropout,
31             hidden_dim=hidden_dim,
32             num_attn_heads=num_attn_heads,
33         )
34
35         self.mlp_ln = nn.LayerNorm(hidden_dim)
36         self.mlp = MLP(hidden_dim, expansion_factor, dropout)
37
38     def forward(self, inputs):
39         z_attn = self.attn_ln(inputs)
40         z_attn = self.attn(z_attn) + inputs
41
42         z_mlp = self.mlp_ln(z_attn)
43         z_mlp = self.mlp(z_mlp) + z_attn
44
45         return z_mlp
```

which corresponds to the schematic function

$$z \leftarrow z + \text{MLP}(\text{LayerNorm}(z + \text{CausalAttention}(\text{LayerNorm}(z)))) , \quad (1.10)$$

indices suppressed.

And then the entire architecture:

```
7 class DecoderOnly(nn.Module):
8     def __init__(self,
```

```

10     block_size=K,
11     dropout=0.1,
12     expansion_factor=E,
13     hidden_dim=D,
14     num_attn_heads=A,
15     num_layers=L,
16     vocab_size=V,
17   ):
18     super().__init__()
19     self.block_size = block_size
20     self.dropout = dropout
21     self.expansion_factor = expansion_factor
22     self.hidden_dim = hidden_dim
23     self.num_attn_heads = num_attn_heads
24     self.num_layers = num_layers
25     self.vocab_size = vocab_size
26
27     self.embedding = nn.Embedding(vocab_size, hidden_dim)
28     self.pos_encoding = nn.Parameter(torch.randn(1, block_size, hidden_dim))
29     self.drop = nn.Dropout(dropout)
30     self.trans_blocks = nn.ModuleList(
31       [
32         TransformerBlock(
33           block_size=block_size,
34           dropout=dropout,
35           expansion_factor=expansion_factor,
36           hidden_dim=hidden_dim,
37           num_attn_heads=num_attn_heads,
38           num_layers=num_layers,
39           vocab_size=vocab_size,
40         )
41         for _ in range(num_layers)
42       ]
43     )
44     self.final_ln = nn.LayerNorm(hidden_dim)
45     self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
46     self.lm_head.weight = self.embedding.weight # Weight tying.
47
48   def forward(self, inputs):
49     S = inputs.shape[1]
50     z = self.embedding(inputs) + self.pos_encoding[:, :S]
51     z = self.drop(z)
52     for block in self.trans_blocks:
53       z = block(z)
54     z = self.final_ln(z)
55     z = self.lm_head(z)
56     return z

```

## 1.7 The Loss Function

The last necessary component is the loss function. The training loop data is canonically the  $(B, K)$ -shaped<sup>14</sup> token inputs ( $x_{bs}$ ) along with their shifted-by-one relatives  $y_{bs}$  where  $x[:, s + 1] == y[:, s]$ .

---

<sup>14</sup>K is the block size, the maximum sequence-length for the model. See App. A.

The  $(B, K, V)$ -shaped outputs ( $z_{bsv}$ ) of the `DecoderOnly` network are treated as the logits which predict the value of the next token, given the present context:

$$p(x_{b(s+1)} = v | x_{bs}, x_{b(s-1)}, \dots, x_{b0}) = \text{Softmax}_v z_{bsv} \quad (1.11)$$

and so the model is trained using the usual cross-entropy/maximum-likelihood loss

$$\begin{aligned} \mathcal{L} &= -\frac{1}{BK} \sum_{b,s} \ln p(x_{b(s+1)} = v | x_{bs}, x_{b(s-1)}, \dots, x_{b0}) \\ &= \frac{-1}{BK} \sum_{b,s} \text{Softmax}_v z_{bsv} \Big|_{v=y_{bs}}. \end{aligned} \quad (1.12)$$

Note that the losses for all possible context lengths are included in the sum<sup>15</sup>.

In `torch` code, the loss computation might look like the following (using fake data):

```

7  def test_loss():
8      model = DecoderOnly(
9          num_attn_heads=A,
10         block_size=K,
11         dropout=0.1,
12         expansion_factor=E,
13         hidden_dim=D,
14         num_layers=L,
15         vocab_size=V,
16     )
17     tokens = torch.randint(model.vocab_size, size=(B, model.block_size + 1))
18     inputs, targets = tokens[:, :-1], tokens[:, 1:]
19     outputs = model(inputs)
20     outputs_flat, targets_flat = outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1)
21     loss = F.cross_entropy(outputs_flat, targets_flat)
22     assert loss

```

## 2 Architecture Variants

There are, of course, many variants on the basic architecture. Some particularly important ones are summarized here.

### 2.1 Multi-Query Attention

In [7], the  $A$  different key and value matrices are replaced by a single matrix each, while  $A$  different query-heads remain. The mechanisms are otherwise unchanged: where there were previously distinct key and value tensors used across different heads, we just use the same tensors everywhere. This is **Multi-Query Attention** (MQA).

The primary reason for multi-query attention is that it vastly reduces the size of the kv-cache (see Sec. 9) during inference time, decreasing the memory-burden of the cache by a factor of  $A$ . This strategy also reduces activation memory during training, but that is more of a side-effect.

---

<sup>15</sup>In Natural Language Processing (NLP), the **perplexity** is often reported instead of the loss, which is just the exponential of the loss, a geometric-mean over the gold-answer probabilities:  $\text{perplexity} = e^{\mathcal{L}} = \left( \prod_{b,s} p(x_{b(s+1)} = v | x_{bs}, x_{b(s-1)}, \dots, x_{b0}) \right)^{\frac{1}{BK}}$ .

## 2.2 Grouped Attention

**Grouped Query Attention** (GQA) [9] is the natural extension of multi-query-attention to using  $1 < G < A$  matrices for key and value generation. Each of the  $G$  different keys gets matched up with  $A/G$  heads (nice divisibility assumed)<sup>16</sup>.

## 2.3 Parallel MLP and CausalAttention Layers

Rather than first pass inputs into the `CausalAttention` layer of each block, and then pass those outputs on to `MLP` in series, `GPT-J-6B` instead processes the `LayerNorm` outputs in *parallel*. That is, instead of something like

$$z \leftarrow z + \text{MLP}(\text{LayerNorm}(z + \text{CausalAttention}(z))) \quad (2.1)$$

we instead have<sup>17</sup>

$$z \leftarrow z + \text{MLP}(z) + \text{CausalAttention}(z) . \quad (2.2)$$

Note that a `LayerNorm` instance is also removed.

## 2.4 RoPE Embeddings

A shortcoming of traditional embeddings  $x_{bsd} \rightarrow x_{bsd} + p_{sd}$  is that they do not generalize very well: a model trained on such embeddings with a maximum sequence length  $K$  will do very poorly when evaluated on longer sequences. RoPE (Rotary Position Embedding) [12] and variants thereof can extend the viable context length by more clever mechanisms with stronger implicit biases.

RoPE and its variants can be motivated by a few natural conditions. Given the queries and keys for an input  $q_{sd}, k_{sd}$  (suppressing batch indices), the corresponding attention scores computation  $a_{ss'}(q_s, k_{s'})$  should reasonably satisfy the below:

1. The attention score should only depend on the position indices  $s, s'$  through their difference  $s - s'$ , i.e., through their relative distance to each other.
2. The score computation should still be efficient, i.e., based on matrix-multiplications.
3. The operation should preserve the scale of the intermediate representations and attention scores, in order to avoid issues with standard normalization.

These conditions suggest a very natural family of solutions: rotation of the queries by some fixed element of  $SO(d)$  with a generator proportional to the position index, and rotation of keys by the conjugate element,

$$q_{sd} \longrightarrow \left[ e^{is\hat{n} \cdot T} \right]_{dd'} q_{sd'} \equiv R(s)_{dd'} q_{sd'}$$

---

<sup>16</sup>Llama-2 [10] uses GQA with  $G = 8$ , seemingly chosen so that each group can be sharded and put on its own GPU within a standard 8-GPU node.

<sup>17</sup>This alternative layer was also used in PaLM [11] where it was claimed that this formulation is  $\sim 15\%$  faster due to the ability to fuse the `MLP` and `CausalAttention` matrix multiplies together (though this is not done in the `GPT-J-6B` repo above).

$$k_{sd} \longrightarrow \left[ e^{-is\hat{n}\cdot T} \right]_{dd'} k_{sd'} \equiv R(s)_{dd'}^\dagger . \quad (2.3)$$

Performing the above computation with a dense element of  $SO(D)$  is infeasible, as it would require a new dense matrix-multiply by a unique  $D \times D$  matrix at each sequence position<sup>18</sup>

In the original RoPE paper, the rotation  $\hat{n}$  was chosen such that the matrices are  $2 \times 2$  block-diagonal with the entries of the form<sup>19</sup>

$$R(s)_{[d:d+2][d:d+2]} = \begin{pmatrix} \cos(s\theta_d) & -\sin(s\theta_d) \\ \sin(s\theta_d) & \cos(s\theta_d) \end{pmatrix} \quad (2.4)$$

where

$$\theta_d = 10^{-8d/D} . \quad (2.5)$$

## 2.5 Flash Attention

Flash Attention [13, 14] optimizes the self attention computation by never materializing the  $\mathcal{O}(S^2)$  attention scores in off-chip memory. This increases the arithmetic intensity of the computation and reduces the activation memory required, at the expense of needing recomputation in the backwards pass.

The central idea is to decompose the attention computation in the following way. Dropping the batch index, let  $q_{sd}, k_{sd}, v_{sd}$  be the queries, keys, and values, and  $z_{sd}$  be the final output. Splitting into attention heads as in  $q_{sd} = q_{s(ah)} \longrightarrow q_{sah}$  and similar the computation is

$$z_{sah} = \text{Softmax}_{s'} (q_{sah'} k_{s'ah'}) v_{s'ah} \quad (2.6)$$

which is the concatenated as  $z_{s(d)} \longrightarrow z_{sd}$  to get the result.

The  $\mathcal{O}(AS^2)$  attention scores  $\text{Softmax}_{s'} (q_{sah'} k_{s'ah'})$  are expensive to shuttle back and forth from off-chip to on-chip memory and are a costly activation memory component. To this end, Flash Attention functions by further splitting  $q, z$  as in  $q_{sd} = q_{(ir)(ah)} \longrightarrow q_{airh}$ , and similar for  $z$ , and  $k, v$  where  $k_{sd} = k_{(jc)(ah)} \longrightarrow k_{ajch}$ , and similar for  $v$ , so that the above reads

$$\begin{aligned} z_{airh} &= \text{Softmax}_{j'c'} (q_{airh'} k_{j'c'ah'}) v_{j'c'ah} \\ &= P_{airj'c'} v_{j'c'ah} \end{aligned} \quad (2.7)$$

where range of the new indices is dictated by the available on-chip memory and is analyzed below.

The algorithm functions by performing the sums over  $j, c', h'$  on-chip. We will jump right to the second version of the algorithm. The algorithm operates on each attention head independently and we omit the head index  $a$  (and batch index  $b$ ) for clarity, and also change the notation from [14] somewhat. In addition to the final outputs  $z_{irh}$ , the algorithm saves the  $S$ -shaped Softmax denominators  $\ell_{ir}$  which are related to the softmax probabilities as in

$$P_{irjc} = \frac{\exp(q_{irh} k_{jch})}{\ell_{ir}} . \quad (2.8)$$

---

<sup>18</sup>For one, the  $\mathcal{O}(SD^2)$  memory cost to store the matrices would be prohibitive. The FLOPs cost is only  $2BSD^2$ , the same as for other matrix multiplies, but because different matrices are needed at position, these FLOPs would be much more GPU memory-bandwidth intensive.

<sup>19</sup>If  $D$  isn't even, the vectors are padded by an extra zero.

Because only the Softmax outputs are needed to compute the Softmax derivative (3.10), this is sufficient data to perform the backwards pass, at the cost of a little re-computation in the form of  $q, k$  matrix-multiples.

The entire forward-pass algorithm is as follows. We capitalize tensors which only exist on-chip and leave tensors which come from off-chip lower-cased.

---

**Algorithm 1** Flash Attention Version Two (Forward)

---

```

1: Initialize off-chip tensors  $z_{irh}, \ell_{ir}$  to zeros
2: for  $i \in \dots$  do
3:   Move  $q_{irh}$  on-chip, instantiate  $Z_{ijrh}, L_{ijr}$  to zeros and  $M_{ijr}$  to  $-\infty$  on-chip.
4:   for  $j \in \dots$  do
5:     Move  $k_{jch}, v_{jch}$  on-chip
6:      $S_{ijrc} \leftarrow q_{irh}k_{jch}$                                  $\triangleright S_{ijrc}$  is a temp, only needed for one  $i, j$  iter
7:      $M_{ijr} \leftarrow \max(M_{i(j-1)r}, \max_c S_{ijrc})$ 
8:      $\tilde{P}_{ijrc} \leftarrow \exp(S_{ijrc} - M_{ijr})$ 
9:      $L_{ijr} \leftarrow e^{M_{i(j-1)r} - M_{ijr}} L_{ijr} + \sum_c \tilde{P}_{ijrc}$        $\triangleright$  Correct, update denominators
10:     $Z_{ijrh} \leftarrow e^{M_{i(j-1)r} - M_{ijr}} Z_{ijrh} + \tilde{P}_{ijrc}v_{jch}$        $\triangleright$  Correct, update numerators
11:     $z_{irh} \leftarrow \frac{1}{L_{i(-1)r}} Z_{i(-1)rh}$        $\triangleright$  Divide  $z$  by last  $j$  value of  $L_{ijc}$ 
12:     $\ell_{ir} \leftarrow M_{i(-1)r} + \ln L_{i(-1)r}$        $\triangleright$  Store sufficient statistics
13:    Write  $z_{irh}, \ell_{ir}$  to off-chip memory

```

---

Essentially, one just computes partial values for each term in the softmax sum, correcting old results and adding new terms to the sum with each iteration. Whereas vanilla attention requires  $\mathcal{O}(S^2 + DS)$  memory accesses per attention head, flash attention requires fewer. Let  $M$  be the amount of on-chip memory in bytes and let  $R$  be the range of  $r \in \{0, \dots, R-1\}$  and similar for other indices. Being able to fit all of the above tensors in memory requires  $CH + RH + RC \lesssim \mathcal{O}(M)$  most of the memory accesses above come from the  $q, z$  accesses in the inner loop, which access  $\mathcal{O}\left(\frac{HS^2}{C}\right) \lesssim \mathcal{O}\left(\frac{H^2S^2}{M}\right)$  bytes of memory, using the inequality on memory. Since  $M \sim 10^5$  bytes on 2023 GPUs, this is a win for typical values of  $H \sim 10^2$ .

Flash attention is also a big win for activation memory: a naive algorithm has a  $\mathcal{O}(ABS^2)$  per-layer contribution to activation memory due to needing to save the attention weights, but these are discarded and re-computed for flash attention. The only additional memory cost comes from the  $\mathcal{O}(ABS)$  elements in the  $\ell_{abs}$  statistics, which are dominated by the  $\mathcal{O}(BSD)$  costs from needing to save inputs, and hence negligible.

# Part II

# Training

## 3 Memory

In this section we summarize the train-time memory costs of Transformers under various training strategies<sup>20</sup>.

The memory cost is much more than simply the cost of the model parameters. Significant factors include:

- Optimizer states, like those of Adam
- Mixed precision training costs, due to keeping multiple model copies.
- Gradients
- Activation memory<sup>21</sup>, needed for backpropagation.

Because the activation counting is a little more involved, it is in its own section.

### Essentials

Memory costs count the elements of all tensors in some fashion, both from model parameters and intermediate representations. The gradient and optimizer state costs scale with the former quantity:  $\mathcal{O}(N_{\text{params}}) \sim \mathcal{O}(LD^2)$ , only counting the dominant contributions from weight matrices. Activation memory scales with the latter, which for a (B, S, D)-shaped input gives  $\mathcal{O}(BDLS)$  contributions from tensors which preserve the input shape, as well as  $\mathcal{O}(ABLS^2)$  factors from attention matrices.

### 3.1 No Sharding

Start with the simplest case where there is no sharding of the model states. Handling the different parallelism strategies later will be relatively straightforward, as it involves inserting just a few factors here and there.

#### 3.1.1 Parameters, Gradients, Optimizer States, and Mixed Precision

Memory from the bare parameter cost, gradients, and optimizer states are fixed costs independent of batch size and sequence-length (unlike activation memory), so we discuss them all together here. The parameter and optimizer costs are also sensitive to whether or not mixed-precision is used, hence we also address that topic, briefly. We will assume the use of Adam<sup>22</sup> throughout, for simplicity

<sup>20</sup>A nice related blog post is [here](#).

<sup>21</sup>Activations refers to any intermediate value which needs to be cached in order to compute backpropagation. We will be conservative and assume that the inputs of all operations need to be stored, though in practice gradient checkpointing and recomputation allow one to trade caching for redundant compute. In particular, flash attention [13] makes use of this strategy.

<sup>22</sup>Which stores two different running averages per-model parameter.

and concreteness. It will sometimes be useful below to let  $p$  to denote the precision in bytes that any given element is stored in, so `torch.float32` corresponds to  $p = 4$ , for instance. Ultimately, we primarily consider vanilla training in  $p = 4$  precision and `torch.float32/torch.float16` ( $p = 4/p = 2$ ) mixed-precision, other, increasingly popular variants to exist, so we keep the precision variable where we can.

Without mixed precision, the total cost of the `torch.float32` ( $p = 4$  bytes) model and optimizer states in bytes is then

$$M_{\text{model}} = 4N_{\text{params}}, \quad M_{\text{optim}} = 8N_{\text{params}} \quad (\text{no mixed precision}, p = 4) \quad (3.1)$$

where, from the previous section, the pure parameter-count of the decoder-only Transformers architecture is

$$N_{\text{params}} \approx (4 + 2E)L D^2 \times \left( 1 + \mathcal{O}\left(\frac{V}{DL}\right) + \mathcal{O}\left(\frac{1}{D}\right) \right). \quad (3.2)$$

where the first term comes from the `TransformerBlock` weight matrices<sup>23</sup>, the first omitted subleading correction term is the embedding matrix, and the last comes from biases, `LayerNorm` instances, and other negligible factors. The optimizer states cost double the model itself.

The situation is more complicated when mixed-precision is used [15]. The pertinent components of mixed-precision<sup>24</sup>:

- A half-precision ( $p = 2$  bytes) copy of the model is used to perform the forwards and backwards passes
- A second, "master copy" of the model is also kept with weights in full  $p = 4$  precision
- The internal `Adam` states are kept in full-precision

Confusingly, the master copy weights are usually accounted for as part of the optimizer state, in which case the above is altered to

$$M_{\text{model}} = 2N_{\text{params}}, \quad M_{\text{optim}} = 12N_{\text{params}} \quad (\text{mixed precision}). \quad (3.3)$$

The optimizer state is now six times the cost of the actual model used to process data and the costs of (3.3) are more than those of (3.1). However, as we will see, the reduced cost of activation memory can offset these increased costs, and we get the added benefit of increased speed due to specialized hardware. The above also demonstrates why training is so much more expensive than inference.

---

<sup>23</sup>So, in the usual  $E = 4$  case, the `MLP` layers are twice as costly as the `CausalAttention` layers.

<sup>24</sup>A note on the implementation of mixed-precision in `torch`: usually mixed-precision occurs by wrapping the forward pass in a context manager, `torch.autocast`. The default behavior is to then create copies of some tensors in lower-precision and do the forward pass with those. For instance, this is done with matrix-multiplies whose arguments and outputs will be in `torch.float16`, but for sums the inputs and outputs will all be `torch.float32`, for vanilla mixed-precision usage. Consequently, any such `torch.float16` versions of tensor will often persist effectively as contributors to activation memory, since the backwards pass will need those same tensors. This can be verified by inspecting the saved tensors: if `z` is the output of a matrix-multiply in such an autocast context, `z.grad_fn._saved_mat2` will be a `torch.float16` copy of the weights used to perform the matrix-multiply. In effect, the cost of the model weights which are used for the actual forward pass are only materialized within the lifetime of the context manager.

### 3.1.2 Gradients

Gradients are pretty simple and always cost the same regardless of whether or not mixed-precision is used:

$$M_{\text{grad}} = 4N_{\text{params}} . \quad (3.4)$$

In mixed precision, even though the gradients are initially computed in  $p = 2$ , they have to be converted to  $p = 4$  to be applied to the master weights of the same precision.

### 3.1.3 Activations

Activations will require a more extended analysis [5]. Unlike the above results, the activation memory will depend on both the batch size and input sequence length,  $B$  and  $S$ , scaling linearly with both.

**Attention Activations** We will count the number of input elements which need to be cached. Our  $(B, S, D)$ -shaped inputs to the attention layer with  $BDS$  elements are first converted to  $3BDS$  total query, key, value elements, and the query-key dot products produce  $ABS^2$  more, which are softmaxed into  $ABS^2$  normalized scores. The re-weighted inputs to the final linear layer also have  $BDS$  elements, bringing the running sum to  $BS(5D + 2AS)$

Finally, there are also the dropout layers applied to the normalized attention scores and the final output whose masks must be cached in order to backpropagate. In torch, the mask is a `torch.bool` tensor, but surprisingly these use one *byte* of memory per element, rather than one bit <sup>25</sup>. Given this, the total memory cost from activations is

$$M_{\text{act}}^{\text{Attention}} = BLS((5p + 1)D + (2p + 1)AS) . \quad (3.5)$$

**MLP Activations** First we pass the  $(B, S, D)$ -shaped inputs into the first MLP layer. These turn into the  $(B, S, E*D)$  inputs of the non-linearity, whose same-shaped outputs are then passed into the last `Linear` layer, summing to  $(2E+1)BDS$  total elements thus far. Adding in the dropout mask, the total memory requirement across all MLP layers is:

$$M_{\text{act}}^{\text{MLP}} = (2Ep + p + 1)BDLS . \quad (3.6)$$

**LayerNorm, Residual Connections, and Other Contributions** Then the last remaining components. The `LayerNorm` instances each have  $BDS$  inputs and there are two per transformer block, so  $M_{\text{act}}^{\text{LayerNorm}} = 2pBDLS$ , and there is an additional instance at the end of the architecture<sup>26</sup>. There are two residual connections per block, but their inputs do not require caching (since their derivatives are independent of inputs). Then, there are additional contributions from pushing the last layer’s outputs through the language-model head and computing the loss function, but these do not scale with  $L$  and are ultimately  $\sim \mathcal{O}(\frac{V}{DL})$  suppressed, so we neglect them.

---

<sup>25</sup>As you can verify via `4 * torch.tensor([True]).element_size() == torch.tensor([1.]).element_size()`.

<sup>26</sup>Following [5] we will neglect this in the below sum, an  $\mathcal{O}(1/L)$  error

**Total Activation Memory** Summing up the contributions above, the total activation memory cost per-layer is

$$M_{\text{act}}^{\text{total}} \approx 2BDLS \left( p(E+4) + 1 + \mathcal{O} \left( \frac{V}{DL} \right) \right) + ABLS^2 (2p+1) . \quad (3.7)$$

Evaluating in common limits, we have:

$$\begin{aligned} M_{\text{act}}^{\text{total}} \Big|_{E=4,p=4} &= BLS (66D + 15AS) \\ M_{\text{act}}^{\text{total}} \Big|_{E=4,p=2} &= BLS (34D + 5AS) \end{aligned} \quad (3.8)$$

**When does mixed-precision reduce memory?** (Answer: usually.) We saw in Sec. 3.1.1 that mixed precision *increases* the fixed costs of non-activation memory, but from the above we also see that it also *reduces* the activation memory and the saving increase with larger batch sizes and sequence lengths. It is straightforward to find where the tipping point is. Specializing to the case  $E = 4$ , vanilla mixed-precision case with no parallelism<sup>27</sup>, the minimum batch size which leads to memory savings is

$$B_{\min} = \frac{6D^2}{8DS + AS^2} . \quad (3.9)$$

Plugging in numbers for the typical  $\mathcal{O}(40 \text{ GiB})$  model in the Summer of 2023 gives  $B_{\min} \sim \mathcal{O}(1)$ , so mixed-precision is indeed an overall savings at such typical scales.

---

<sup>27</sup>With both tensor- and sequence-parallelism, the parallelism degree  $T$  actually drops out in the comparison (since both form of memory are decrease by  $1/T$ , so this restriction can be lifted).

### Side Note: Optimizations

The above analysis is conservative and accounts for more tensors than are actually saved in practice.

For instance, both the input and outputs of all non-linearities were counted, but there are many activations whose derivatives can be reconstructed from its outputs alone:  $\phi'(z) = F(\phi(z))$  for some  $F$ . Examples:

- **ReLU**: since  $\phi(z) = z\theta(z)$ , then (defining the derivative at zero to be zero)  $\phi'(z) = \theta(z) = \theta(\phi(z))$ . Correspondingly, torch only uses the ReLU outputs [to compute the derivative](#) (there is no self arg in the `threshold_backward(grad, result, 0)` line).
- **tanh**: since  $\tanh'(z) = 1 - \tanh(z)^2$ .

Other cases do not have this nice property, in which case both the inputs and outputs need to be stored:

- **GeLU [16]**:  $\phi(z) = z\Phi(z)$  here and the derivative  $\phi'(z) = \Phi(z) + \frac{ze^{-z^2/2}}{\sqrt{2\pi}}$ , both the inputs and outputs [must be used in the backwards pass..](#)

The explicit CUDA kernel [is here](#).

If the inputs in each of these cases are not needed for any other part of the backwards pass, they are garbage collected in `torch` soon after creation.

**Example** : **Softmax** is another instance where this occurs, since

$$\partial_i \text{Softmax}(x_j) = \delta_{ij} \text{Softmax}(x_j) - \text{Softmax}(x_i) \text{Softmax}(x_j) \quad (3.10)$$

Because of this, the actual amount of activation memory due to the attention layer after the forwards pass is (3.5) with  $2p \rightarrow p$  in the  $\mathcal{O}(S^2)$  term, though the above expression better reflects the necessary peak memory.

## 3.2 Tensor Parallelism

In **Tensor Parallelism**, sometimes also called **Model Parallelism**, individual weight matrices are split across devices [17]. We consider the **MLP** and **CausalAttention** layers in turn. Assume  $T$ -way parallelism such that we split some hidden dimension into  $T$ -equal parts across  $T$  workers<sup>28</sup>

---

<sup>28</sup>All  $T$  workers work on processing the same batch collectively. With  $N > T$  workers, with  $N$  perfectly divisible by  $T$ , there are  $N/T$  different data parallel groups. Critical-path TP communications occur within each data parallel group and gradients are synced across groups. Ideally, all the workers in a group reside on the same node, hence the usual  $T = 8$ .

## Essentials

The cost of large weights can be amortized by first sharding its output dimension, resulting in differing activations across group members. Later, the activations are brought back in sync via a `AllReduce`. Weights which act on the sharded-activations can also be sharded in their input dimension. In the backwards pass, another `AllReduce` is required.

**MLP** It is straightforward to find the reasonable ways in which the weights can be partitioned. We suppress all indices apart from those of the hidden dimension for clarity.

The first matrix multiply  $z_d W_{de}^0$  is naturally partitioned across the output index, which spans the expanded hidden dimension  $e \in \{0, \dots, ED - 1\}$ . This functions by splitting the weight matrix across its output indices across  $T$  devices:  $W_{de}^0 = W_{d(ft)}^0 \equiv \bar{W}_{df\bar{t}}^0$  (again in `einops`-like notation, with bars denoting that the tensor and particular indices are sharded; see App. A), where in the split weights  $\bar{t} \in \{0, \dots, T - 1\}$ , and  $f \in \{0, \dots, \frac{ED}{T} - 1\}$ . Each of the  $T$  workers compute one shard of  $z_d \bar{W}_{df\bar{t}}^0$ , i.e. each has a different value of  $\bar{t}$ . Note that each worker will have to store all components of the input  $z$  for their backward pass, and in the backwards pass `AllReduce` operation (see App. B) will be needed<sup>29</sup> to collect gradient shards from other.

Let the partial outputs from the previous step be  $\bar{z}_{ft}$  (batch-index suppressed), which are (B, S, E\*D/T, T)-shaped, with the final dimension sharded across workers. The non-linearity  $\phi$  acts element wise, and using the updated  $\bar{z}_{f\bar{t}}$  to compute the second matrix multiply requires a splitting the weights as in  $W_{ed'}^1 = W_{(ft)d'}^1 \equiv \bar{W}_{f\bar{t}d'}^1$  (dividing up the incoming  $e$  dimension), such that the desired output is computed as in  $\bar{z}_{f\bar{t}} \cdot \bar{W}_{f\bar{t}d'}^1$ , sum over  $\bar{t}$  implied. Each device has only  $\bar{t}$  component in the sum (a (B, S, D)-shaped tensor) and an `AllReduce` is used to give all workers the final result. This `AllReduce` is the only forward-pass collective communication<sup>30</sup>.

One-line summary of the parallel decomposition:

$$z_{sd'} \leftarrow \phi(z_d W_{de}^0) W_{ed'}^1 = \phi(z_d \bar{W}_{df\bar{t}}^0) \bar{W}_{f\bar{t}d'}^1 . \quad (3.11)$$

The progression of tensor shapes held by any single worker is

1. (B, S, D)
2. (B, S, E\*D/T)
3. (B, S, D)

**Attention** Because the individual attention head computations are independent, they can be partitioned across  $T$  workers without collectively communications. An `AllReduce` is needed for the final projection, however, which results in the various re-weighted values  $y_{bsea}$  (1.5).

To review, the attention outputs  $z'_{sd}$  generated from inputs  $z_{sd}$  can be expressed as

$$z'_{sea} = \text{MHA}(q_{sea}, k_{sea}, v_{sea}) O_{ead}$$

<sup>29</sup>This is needed at the step of computing the gradient with respect to the original input. At the first step of a sharded computation when each worker gets  $y_{e\bar{t}} = z_d W_{de\bar{t}}$  for one value of  $\bar{t}$ , they only hold this slice of the activations. In order to compute  $\partial_{z_d} \mathcal{L} = \partial_{y_{e\bar{t}}} \mathcal{L} \times W_{de\bar{t}}$ , the  $\bar{t}$  sum must be computed by an `AllReduce`

<sup>30</sup>The amount of communicated data is  $\mathcal{O}(BSD)$ .



**Figure 2.** Tensor parallelism for the MLP layers. Graphic from [17]. The  $f/g$  operations are the collective identity/`AllReduce` operations in the forwards pass and the `AllReduce`/identity operations in the backwards pass.

(3.12)

where:

- We have split the  $d$ -index as in  $z_{sd} \rightarrow z_{s(ea)}$  with  $e$  and  $a$  the head-dimension and head-index
- $q_{sea}, k_{sea}, v_{sea}$  are the query, keys and values derived from the inputs
- MHA is the multi-head attention function, whose outputs are the same shape as its value inputs
- The dual sum over head-dimension index ( $e$ ) and attention-head-index ( $a$ ) is the sum-and-concatenate step from the more explicit description in Sec. 1.3
- **Dropout** and biases were ignored for simplicity

In order to parallelize the above  $T$ -ways, we simply shard across the dimension  $a$  which indexes the different attention heads. The MHA computations all process in embarrassingly-parallel fashion, and an all-reduce is needed to complete the sum over the  $a$ -index across devices.

The collective communications story is essentially equivalent to that of the MLP layers<sup>31</sup>: one `AllReduce` is needed in the forwards pass and one `AllReduce` in the backwards-pass.

The progression of tensor shapes held by any single worker is

1.  $(B, S, D)$
2.  $(B, S, D/A, A/T)$

---

<sup>31</sup>The amount of communicated data is again  $\mathcal{O}(BSD)$ .

### 3. (B, S, D)

It is worth comparing the communications and FLOPs costs of these sharded layers. Each layer costs  $\mathcal{O}(BS(4+2E)D^2/T)$  FLOPs and communicates  $\mathcal{O}(BSD)$  bytes and so the communication-to-compute-time ratio is

$$\frac{t_{\text{compute}}}{t_{\text{comms}}} \sim \frac{(4+2E)D}{T} \times \frac{\lambda_{\text{comms}}}{\lambda_{\text{FLOP/s}}} . \quad (3.13)$$

Since<sup>32</sup>  $\frac{\lambda_{\text{comms}}}{\lambda_{\text{FLOP/s}}} \sim 10^{-3}$  FLOPs/B, communication and compute take similar times when  $D \sim \mathcal{O}(10^3)$  for typical setups with  $T \sim \mathcal{O}(10)$  and so tensor-parallelism requires  $D \gtrsim 10^4$  to reach similar efficiency to the non-tensor-parallel implementations.



**Figure 3.** Tensor parallelism for the `CausalAttention` layers. Graphic from [17]. The  $f/g$  operators play the same role as in Fig. 2.

**Embedding and LM Head** Last, we can apply tensor parallelism to the language model head, which will also necessitate sharding the embedding layer, if the two share weights, as typical.

For the LM head, we shard the output dimension as should be now familiar, ending up with  $T$  different (B, S, V/T)-shaped tensors, one per group member. Rather than communicating these large tensors around and then computing the cross-entropy loss, it is more efficient to have each worker compute their own loss where possible and then communicate the scalar losses around<sup>33</sup>.

<sup>32</sup>Assuming  $\lambda_{\text{FLOP/s}} \sim 100$  TFLOP/s and  $\lambda_{\text{comms}} \sim 100$  GiB/s.

<sup>33</sup>In more detail, given the gold-answers  $y_{bs}$  for the next-token-targets, a given worker can compute their contribution to the loss whenever their (B, S, V/T)-shaped output  $z_{bsv'}$  contains the vocabulary dimension  $v_*$  specified by  $y_{bs}$ , otherwise those tensor components are ignored.

For a weight-tied embedding layer, the former construction requires `AllReduce` in order for every worker to get the full continuous representation of the input.

**LayerNorm and Dropout** `LayerNorm` instances are not sharded in pure tensor parallelism both because there is less gain in sharding them parameter-wise, but also sharding `LayerNorm` in particular would require additional cross-worker communication, which we wish to reduce as much as possible. `Dropout` layers are also not sharded in where possible in pure tensor parallelism, but sharding the post-attention `Dropout` layer is unavoidable. It is the goal of sequence parallelism is to shard these layers efficiently; see Sec. 3.3.

**Effects on Memory** The per-worker memory savings come from the sharding of the weights and the reduced activation memory from sharded intermediate representations.

The gradient and optimizer state memory cost is proportional to the number of parameters local to each worker (later we will also consider sharding these components to reduce redundantly-held information). The number of parameters per worker is reduced to

$$N_{\text{params}} \approx (4 + 2E) \frac{LD^2}{T} , \quad (3.14)$$

counting only the dominant contribution from weights which scale with  $L$ , since every weight is sharded. Since all non-activation contributions to training memory scale with  $N_{\text{params}}$ , this is a pure  $1/T$  improvement.

The per-layer activation memory costs (3.5) and (3.6) are altered to:

$$\begin{aligned} M_{\text{act}}^{\text{Attention}} &= BS \left( \left( p + \frac{4p}{T} + 1 \right) D + \left( \frac{2p+1}{T} \right) AS \right) \\ M_{\text{act}}^{\text{MLP}} &= \left( \frac{2Ep}{T} + p + 1 \right) BDS . \end{aligned} \quad (3.15)$$

The derivation is similar to before. Adding in the (unchanged) contributions from `LayerNorm` instances, the total, leading order activation memory sums to

$$M_{\text{act}}^{\text{total}} \approx 2BDLS \left( p \left( 2 + \frac{E+2}{T} \right) + 1 \right) + ABLS^2 \left( \frac{2p+1}{T} \right) . \quad (3.16)$$

Again, the terms which did not receive the  $1/T$  enhancement correspond to activations from unsharded `LayerNorm` and `Dropout` instances and the  $1/T$ 's improvements can be enacted by layering sequence parallelism on top (Sec. 3.3).

### 3.3 Sequence Parallelism

In (3.16), not every factor is reduced by  $T$ . **Sequence Parallelism** fixes that by noting that the remaining contributions, which essentially come from `Dropout` and `LayerNorm`<sup>34</sup>, can be parallelized in the sequence dimension (as can the residual connections).

The collective communications change a bit. If we shard the tensors across the sequence dimension before the first `LayerNorm`, then we want the following:

---

<sup>34</sup>Recall, though, from Sec. 1.2 that the parameters in `LayerNorm` are completely redundant and can simply be removed without having any effect on the expressive capabilities of the architecture.

1. The sequence dimension must be restored for the **CausalAttention** layer
2. The sequence should be re-split along the sequence dimension for the next **LayerNorm** instance
3. The sequence dimension should be restored for the **MLP** layer <sup>35</sup>

The easiest way to achieve the above is the following.

1. If the tensor parallelization degree is  $T$ , we also use sequence parallelization degree  $T$ .
2. The outputs of the first **LayerNorm** are **AllGather**-ed to form the full-dimension inputs to the **CausalAttention** layer
3. The tensor-parallel **CausalAttention** layer functions much like in Fig. 3 *except* that we do not re-form the outputs to full-dimensionality. Instead, before the **Dropout** layer, we **ReduceScatter** them from being hidden-sharded to sequence-sharded and pass them through the subsequent **Dropout/LayerNorm** combination, similar to the first step
4. The now-sequence-sharded tensors are reformed with another **AllGather** to be the full-dimensionality inputs to the **MLP** layer whose final outputs are similarly **ReduceScatter**-ed to be sequence-sharded and are recombined with the residual stream

The above allows the **Dropout** mask and **LayerNorm** weights to be sharded  $T$ -ways, but if we save the full inputs to the **CausalAttention** and **MLP** layers for the backwards pass, their contributions to the activation memory are not reduced (the  $p$ -dependent terms in (3.15)). In [5], they solve this by only saving a  $1/T$  shard of these inputs on each device during the forward pass and then performing an extra **AllGather** when needed during the backwards pass. Schematics can be seen in Fig. 4 and Fig. 5 below. The activation memory is then reduced to:

$$M_{\text{act}}^{\text{total}} = \frac{2BDLS(p(E+4)+1)}{T} + \frac{ABLS^2(2p+1)}{T} + \mathcal{O}(BSV) . \quad (3.17)$$

In more detail:

- The norms are just linear operations on the  $z_{sd}$ ,  $z'_{sd} = \text{Norm}(z_{sd})$ , and so we split and shard cross the sequence dimension  $z_{sd} \rightarrow z_{(tr)d} \equiv \bar{z}_{\bar{t}rd}$  with the TP-index  $t$  sharded across devices.
- The residual stream is also sharded across the sequence dimension.
- The sharded outputs  $\bar{z}_{\bar{t}rd}$  must be re-formed to create the attention and MLP inputs via an **AllGather**. There is an optimization choice here: either the re-formed tensors can be saved for the backward pass (negating the  $1/T$  memory savings) or they can be re-formed via an **AllGather**, at the cost of extra communication.
- Both the MLP and attention layers need to produce final sums of the form  $\bar{y}_{s\bar{y}e}\bar{O}_{\bar{t}ed}$  for some intermediate  $\bar{y}$  and weight  $\bar{O}$  sharded across the TP-dimension  $\bar{t}$ . The outputs are added to the sequence-sharded residual stream, and so sum is optimally computed through an **ReduceScatter** with final shape  $\bar{z}_{\bar{t}'rd} = z_{(t'r)d} = z_{sd} = \bar{y}_{s\bar{t}e}\bar{O}_{\bar{t}ed}$ . This **ReduceScatter** (along with the **AllGather** mentioned above) replace the **AllReduces** from the tensor-parallel case and have the overall communication cost.

---

<sup>35</sup>This doesn't seem like a hard-requirement, but it's what is done in [5].



**Figure 4.** Interleaved sequence and tensor parallel sections.  $g$  and  $\bar{g}$  are AllGather and ReduceScatter in the forward pass, respectively, and swap roles in the backwards pass. Graphic from [17].



**Figure 5.** Detail of the sequence-tensor parallel transition for the MLP . Graphic from [17].

### 3.4 Pipeline Parallelism

TODO

### 3.5 Case Study: Mixed-Precision GPT3

Let's run through the numbers for mixed-precision GPT3 with parameters:

$$L = 96, \quad D = 12288, \quad A = 96, \quad V = 50257. \quad (3.18)$$

We are leaving the sequence-length unspecified, but the block-size (maximum sequence-length) is  $K = 2048$ .

Start by assuming no parallelism at all. The total (not per-layer!) non-activation memory is

$$M_{\text{non-act}}^{\text{GPT-3}} \approx 1463 \text{ TiB} \quad (3.19)$$

which can be broken down further as

$$M_{\text{params}}^{\text{GPT-3}} \approx 162 \text{ TiB}, \quad M_{\text{grads}}^{\text{GPT-3}} \approx 325 \text{ TiB}, \quad M_{\text{optim}}^{\text{GPT-3}} \approx 975 \text{ TiB}. \quad (3.20)$$

The embedding matrix only makes up  $\approx .3\%$  of the total number of parameters, justifying our neglect of its contribution in preceding expressions.

The activation memory is

$$M_{\text{act}}^{\text{GPT-3}} \approx 3 \times 10^{-2} BS \times \left(1 + \frac{S}{10^3}\right) \text{ TiB} . \quad (3.21)$$

Note that the attention matrices, which are responsible for  $\mathcal{O}(S^2)$  term, will provide the dominant contribution to activation memory in the usual  $S \gtrsim 10^3$  regime.

In the limit where we process the max block size ( $S = K = 2048$ ), the ratio of activation to non-activation memory is

$$\frac{M_{\text{act}}^{\text{GPT-3}}}{M_{\text{non-act}}^{\text{GPT-3}}} \Big|_{S=2048} \approx .2B . \quad (3.22)$$

So, the activation memory is very significant for such models.

Using tensor parallelism into the above with the maximal  $T = 8$  which can be practically used, the savings are significant. The total memory is now

$$M_{\text{total}}^{\text{GPT-3}} \approx 187 \text{ TiB} + 10^{-2} BS + 5 \times 10^{-6} BS^2 . \quad (3.23)$$

## 4 Training FLOPs

The total number of floating point operations (FLOPs)<sup>36</sup> needed to process a given batch of data is effectively determined by the number of matrix multiplies needed.

Recall that a dot-product of the form  $v \cdot M$  with  $v \in \mathbb{R}^m$  and  $M \in \mathbb{R}^{m,n}$  requires  $(2m - 1) \times n \approx 2mn$  FLOPs. For large language models,  $m, n \sim \mathcal{O}(10^3)$ , meaning that even expensive element-wise operations like GeLU acting on the same vector  $v$  pale in comparison by FLOPs count<sup>37</sup>. It is then a straightforward exercise in counting to estimate the FLOPs for a given architecture. The input tensor is of shape (B, S, D) throughout.

---

<sup>36</sup>The notation surrounding floating-point operations is very confusing because another quantity of interest is the number of floating-point operations a given implementation can use *per-second*. Sometimes, people use FLOPS or FLOP/s to indicate the rate, rather than the gross-count which has the lower case “s”, FLOPs, but there’s little consistency in general. We will use FLOPs and FLOP/s.

<sup>37</sup>Since their FLOPs counts only scales as  $\sim \mathcal{O}(n)$  where the omitted constant may be relatively large, but still negligible when all dimensions are big.

## Essentials

The number of FLOPs to push a batch of  $B$  of sequence-length  $S$  examples through the forwards-pass of a decoder-only transformer is approximately  $2BSN_{\text{params}}$  where the number of parameters accounts for any reductions due to tensor- and sequence-parallelism<sup>a</sup>. The backwards-pass costs about twice as much as the forwards-pass. This is true as long as  $S \lesssim D$ .

<sup>a</sup>A quick argument: a computation of the form  $T_{a_0 \dots a_n j} = V_{a_0 \dots a_n i} M_{ij}$  requires  $2A_0 \dots A_n IJ$  FLOPs where the capital letters represent the size of their similarly-index dimensions. Thus, the FLOPs essentially count the size of the matrix  $M$  (that is,  $IJ$ ), up to a factor of 2 times all of the dimensions in  $V$  which weren't summed over. Therefore, passing a  $(B, S, D)$ -shaped tensor through the Transformer architecture would give  $\sim 2BS \times (\text{sum of sizes of all weight-matrices})$  FLOPs, and that this last factor is also approximately the number of parameters in the model (since that count is dominated by weights). Thus,  $\text{FLOPs} \approx 2BSN_{\text{params}}$ . This is the correct as long as the self-attention FLOPs with  $\mathcal{O}(S^2)$ -dependence which we didn't account for here are actually negligible (true for  $S \lesssim 10D$ ).

### 4.1 No Recomputation

Start with the case where there is no recomputation activations. These are the **model FLOPs** of [5], as compared to the **hardware FLOPs** which account for gradient checkpointing.

**CausalAttention: Forwards** The FLOPs costs:

- Generating the query, key, and value vectors:  $6BSD^2$
- Attention scores:  $2BDS^2$
- Re-weighting values:  $2BDS^2$
- Final projection:  $2BSD^2$

**MLP: Forwards** Passing  $a$  through the MLP layer, the FLOPs due to the first and second matrix-multiplies are equal, with total matrix-multiply FLOPs  $4BSED^2$ .

**Backwards Pass: Approximate** The usual rule of thumb is to estimate the backwards pass as costing twice the flops as the forwards pass. This estimate comes from just counting the number of  $\mathcal{O}(n^2)$  matrix-multiply-like operations and seeing that for every one matrix multiplication that was needed in the forward pass, we have roughly twice as many similar operations in the backwards pass.

The argument: consider a typical sub-computation in a neural network which is of the form  $z' = \phi(W \cdot z)$  where  $z', a$  are intermediate representations  $z, z'$ ,  $\phi$  is some non-linearity, and where the matrix multiply inside the activation function dominates the forwards-pass FLOPs count, as above. Then, in the backwards pass for this sub-computation, imagine we are handed the upstream derivative  $\partial_{z'} \mathcal{L}$ . In order to complete backpropagation, we need both to compute  $\partial_W \mathcal{L}$  to update  $W$  and also  $\partial_z \mathcal{L}$  to continue backpropagation to the next layer down. Each of these operations will cost about as many FLOPs as the forwards-pass, hence the estimated factor of two (but, as we will see, this is a very rough estimate).

Being more precise, let  $z$  be  $(D_0, \dots, D_n, J)$ -shaped and let  $W$  be  $(I, J)$ -shaped such that it acts on the last index of  $z$ , making  $z' (D_0, \dots, D_n, I)$ -shaped. Denoting  $D = \prod_i D_i$  be the number of elements along the  $D_i$  directions for brevity, the forward-FLOPs cost of the sub-computation is therefore  $2DIJ$ .

Adding indices, the two derivatives we need are

$$\begin{aligned}\frac{\partial \mathcal{L}}{\partial W_{ij}} &= \frac{\partial \mathcal{L}}{\partial z'_{d_0 \dots d_n i}} \phi'((W \cdot z)_{d_0 \dots d_n i}) z_{d_0 \dots d_n j} \\ \frac{\partial \mathcal{L}}{\partial z_{d_0 \dots d_n j}} &= \frac{\partial \mathcal{L}}{\partial z'_{d_0 \dots d_n i}} \phi'((W \cdot z)_{d_0 \dots d_n i}) W_{ij},\end{aligned}\quad (4.1)$$

which have shapes  $(I, J)$  and  $(D_0, \dots, D_n, J)$ , respectively. On the right side,  $z$  and  $W \cdot z$  are cached and the element-wise computation of  $\phi'(W \cdot z)$  has negligible FLOPs count, as discussed above: its contribution is  $\mathcal{O}(1/I)$  suppressed relative to the matrix-multiplies. The FLOPs count is instead dominated by the broadcast-multiplies, sums, and matrix-products.

The two derivatives in (4.1) each have the same first two factors in common, and it takes  $DI$  FLOPs to multiply out these two  $(D_0, \dots, D_n, J)$ -shaped tensors into another result with the same shape. This contribution is again  $\mathcal{O}(1/I)$  suppressed and hence negligible. Multiplying this factor with either  $z_{d_0 \dots d_n i}$  or  $W_{ij}$  and summing over the appropriate indices requires  $2DIJ$  FLOPs for either operation, bringing the total FLOPs to  $4DIJ$ , which is double the FLOPs for this same sub-computation in the forward-direction, hence the rough rule of thumb<sup>38</sup>.

## Backwards Pass: More Precise TODO

**Total Model FLOPs** The grand sum is then<sup>39</sup>:

$$C^{\text{model}} \approx 12BDLS(S + (2 + E)D). \quad (4.2)$$

We can also phrase the FLOPs in terms of the number of parameters (3.14) as

$$C^{\text{model}}|_{T=1} = 6BSN_{\text{params}} \times (1 + \mathcal{O}(S/D)) \quad (4.3)$$

where we took the  $T = 1, D \gg S$  limit for simplicity and we note that  $BS$  is the number of total tokens in the processed batches.

## 5 Training Time

Training is generally compute bound (see App. D) and based on the results of Sec. 4 the quickest one could possibly push a batch of data through the model is

$$t_{\min} = \frac{C^{\text{model}}}{\lambda_{\text{FLOP/s}}}. \quad (5.1)$$

---

<sup>38</sup>Note also that the very first layer does not need to perform the second term in (4.1), since we do not need to backpropagate to the inputs, so the total backwards flops is more precisely  $4DIJ(L - 1) + 2DIJ$ .

<sup>39</sup>With a large vocabulary, the cost of the final language model head matrix multiply can also be significant, but we have omitted its  $L$ -independent,  $2BDSV$  contribution here.

Expanding to the entire training run, then with perfect utilization training will take a time

$$t_{\text{total}} \approx \frac{6N_{\text{params}}N_{\text{tokens}}}{\lambda_{\text{FLOP/s}}} . \quad (5.2)$$

Adjust  $\lambda_{\text{FLOP/s}}$  to the actual achievable FLOP/s in your setup to get a realistic estimate.

How many tokens should a model of size  $N_{\text{params}}$ ? Scaling laws (Sec. 6) provide the best known answer, and the Summer 2023 best-guess is that we optimally have  $N_{\text{tokens}} \approx 20N_{\text{params}}$ . So that the above is

$$t_{\text{total}} \approx \frac{120N_{\text{params}}^2}{\lambda_{\text{FLOP/s}}} , \quad (5.3)$$

leading to quadratic growth in training time.

Note that the above is only correct if we are actually only spending  $C^{\text{model}}$  compute per iteration. This is not correct if we use gradient checkpointing and recomputation, in which case we alternatively spend true compute  $C^{\text{hardware}} > C^{\text{model}}$ , a distinction between **hardware FLOPs** and **model FLOPs**. Two corresponding efficiency measures are **model FLOPs utilization** (MFU) and **hardware FLOPs utilization** (HFU). If our iterations take actual time  $t_{\text{iter}}$ , then these are given by

$$\text{MFU} = \frac{t_{\text{iter}}}{t_{\text{min}}^{\text{model}}} , \quad \text{HFU} = \frac{t_{\text{iter}}}{t_{\text{min}}^{\text{hardware}}} , \quad (5.4)$$

where  $t_{\text{min}}^{\text{model}}$  is (5.1) and  $t_{\text{min}}^{\text{hardware}}$  is similar but using  $C^{\text{hardware}}$ .

## 6 Scaling Laws

Empirically-discovered scaling laws have driven the race towards larger and larger models.

### Essentials

Decoder-only model performance improves predictably as a function of the model size, dataset size, and the total amount of compute. As of Summer 2023, there is little sign of hitting any kind of wall with respect to such scaling improvements.

The central parameters are:

- The number of non-embedding model parameters, as excising embedding params was found to generate cleaner scaling laws. Because our  $N_{\text{params}}$  has already been typically neglecting these parameters, we will just use this symbol in scaling laws and keep the above understanding implicit.<sup>40</sup> [18].
- $C$ : total compute, often in units like PFLOP/s-days  $\sim 10^{20}$  FLOPs
- $N_{\text{tokens}}$ : dataset-size in tokens

---

<sup>40</sup>Presumably, the scaling laws are cleaner with these neglected because these params do not contribute directly to FLOPs, unlike most other parameters.

- $\mathcal{L}$ : cross-entropy loss in nats

The specific form of any given scaling law should also be understood to apply to a pretty narrowly defined training procedure, in which choices like the optimizer, learning-rate scheduler, hyperparameter search budget, vocabulary size, tokenization, etc. are often rigidly set. Changing different components of the training procedure is liable to create different scaling laws (though nice laws of some form are still expected to exist).

## 6.1 Original Scaling Laws

The first scaling-laws were reported in [18]. Their simplest form relates the value of the cross-entropy loss *at convergence* (and in nats),  $\mathcal{L}$ , to the number of non-embedding parameter, dataset size in token, and the amount of compute, *in the limit* where only one of this factors is bottlenecking the model<sup>41</sup>. The laws (in our notation):

- $\mathcal{L}(N_{\text{params}}) \approx (N_{\text{params}}^*/N_{\text{params}})^{\alpha_N}$ , with  $\alpha_N \approx 0.076$  and  $N_{\text{params}}^* \approx 8.8 \times 10^{13}$
- $\mathcal{L}(N_{\text{tokens}}) \approx (N_{\text{tokens}}^*/N_{\text{tokens}})^{\alpha_T}$ , with  $\alpha_T \approx 0.095$  and  $N_{\text{tokens}}^* \approx 5.4 \times 10^{13}$
- $\mathcal{L}(C) \approx (C^*/C)^{\alpha_C}$ , with  $\alpha_C \approx 0.050$  and  $C^* \approx 3.1 \times 10^8$  PFLOP/s-days, where the batch size was assumed to be chosen to be compute optimal per the criteria they outline



**Figure 6.** Original scaling laws from [18].

## 6.2 Chinchilla Scaling Laws

As of Summer 2023, the Chinchilla scaling laws in [19] are the de facto best scaling laws for guiding training. The central difference between [19] and [18] is that in the former they adjust their cosine learning-rate schedule to reflect the amount of planned training, while in the latter they do not<sup>42</sup>.

<sup>41</sup>Unclear to me how you know when this is the case?

<sup>42</sup>The learning-rate schedule consist of a linear warm-up stage from a very small  $\eta$  up to the largest value  $\eta_{\max}$ , after which the cosine bit kicks in:  $\eta(s) = \eta_{\min} + (\eta_{\max} - \eta_{\min}) \times \cos\left(\frac{\pi s}{2s_{\max}}\right)$  with  $s$  the step number. In Fig. A1 of [19] they demonstrate that having the planned  $s_{\max}$  duration of the scheduler be longer than the actual number of training steps is detrimental to training (they do not study the opposite regime), which is effectively what was done in [18]. Probably the more important general point is again that the precise form of these scaling laws depend on details of fairly arbitrary training procedure choices, such as the choice of learning-rate scheduler.



**Figure 7.** From [18]. Larger models are much more sample-efficient (faster).

Several different analyses are performed which all give very similar results. The outputs are the optimal values of  $N_{\text{params}}$ ,  $N_{\text{tokens}}$  given a compute budget  $C$ .

- They fix various buckets of model sizes and train for varying lengths. In their resulting loss-vs-FLOPs plot, they determine the model size which led to the best loss at each given FLOPs value, thereby generating and optimal model size vs compute relation.
- They fix various buckets of FLOPs budget and train models of different sizes with that budget, finding the optimal model size in each case. A line can then be fit to the optimal settings across FLOPs budgets in both the parameter-compute and tokens-compute planes.
- They perform a parametric fit to the loss<sup>43</sup>:

$$\mathcal{L}(N_{\text{params}}, N_{\text{tokens}}) = E + \frac{A}{N_{\text{params}}^\alpha} + \frac{B}{N_{\text{tokens}}^\beta}, \quad (6.1)$$

fit over a large range of parameter and token choices. The best-fit values are:

$$E = 1.69, \quad A = 406.4, \quad B = 410.7, \quad \alpha = 0.34, \quad \beta = 0.28. \quad (6.2)$$

Using  $C \approx 6N_{\text{params}}N_{\text{tokens}}$ , the above can be minimized at fixed compute either for number of parameter or the size of the dataset.

In all cases, the findings are that at optimality  $N_{\text{params}} \sim N_{\text{tokens}} \sim C^{.5}$ : both the parameter and tokens budget should be scaled in equal measure.

---

<sup>43</sup>In [19] they model the scaling of the test loss, while in [18] they use the training loss.

# Part III

## Inference

### 7 Basics and Problems

The essentials of decoder-only inference is that a given input sequence  $x_{bs}$  is turned into a probability distribution  $p_{bsv}$  over the vocabulary for what the next token might be. Text is then generated by sampling from  $p_{bsv}$  in some way, appending that value to  $x_{bs}$  to create a one-token-longer sequence, and then repeating until desired.

There are various problems that naive implementations of the above face:

- Repeated computation from processing the same tokens in the same order repeatedly, at least for some sub-slice of  $x_{bs}$ .
- Inherently sequential computation, rather than parallel
- Sub-optimal sampling strategies. Just choosing the most-probable token at each new step, does not guarantee the most-probable overall sequence, for instance.

### 8 Generation Strategies

A quick tour of generation strategies. A very readable blog post comparing strategies can be found [here](#).

#### 8.1 Greedy

The most obvious generation strategy is to take the final, (B, S, V)-shaped outputs  $z_{bsv}$  and just take the next token to be the most-probable one (for the final position in the sequence): `next_token = z[:, -1].argmax(dim=-1)`. A very minimal `generate` method is as below:

```
6  class DecoderOnlyGreedy(DecoderOnly):
7      def __init__(self, *args, **kwargs):
8          super().__init__(*args, **kwargs)
9
10     def generate(self, inputs, max_length):
11         """
12             Naive, minimal generation method. Assumes inputs are already tokenized. max_length can be
13             longer than the block_size, but only up to block_size tokens can ever be included in the
14             context.
15         """
16         self.eval()
17         outputs = inputs.clone()
18         while outputs.shape[1] < max_length:
19             context = outputs[:, -self.block_size :]
20             last_token_pred_logits = self(context)[:, -1]
21             most_probable_token = last_token_pred_logits.argmax(dim=-1)[:, None]
22             outputs = torch.cat([outputs, most_probable_token], dim=-1)
23         return outputs
```

There are various important, practical considerations which are ignored in the above implementation, including:

- Since we are taking the prediction from the last (-1-indexed) element in each sequence, it is crucial that all padding is *left*-padding, so that these final elements are meaningful.
- Models will signal the end of generation by outputting tokenizer-specific codes, and generation must respect these.

See [the `generate` method from the `transformers` library](#) for more fully-featured code (which, correspondingly, is not always easy to follow).

## 8.2 Simple Sampling: Temperature, Top- $k$ , and Top- $p$

The next-most-obvious strategy is to choose the next token by drawing from the probability distribution defined by the  $z_{bsv}$ . There are various refinements of this idea.

A one-parameter generalization of this strategy introduces a (physics-motivated) **Temperature** which just adjusts the scale of the logits:

```
next_token = torch.multinomial((z[:, -1] / temp).softmax(dim=-1), num_samples=1)
```

assuming  $z$  are the final logits. Larger temperature yields a larger variance in the chosen tokens.

With temperature sampling, there is still a non-zero chance of choosing an extremely improbable token, which is undesirable if you do not trust the tails of the distribution. Two common truncation strategies which guard against this:

- **Top- $k$ :** Only choose from the top- $k$  most-probable examples (re-normalizing the probabilities across those  $k$  samples)
- **Top- $p$ :** Only choose from the top-however-many most-probable examples whose probabilities sum to  $p$  (again re-normalizing probabilities). This is also sometimes called **nucleus sampling**.

## 8.3 Beam Search

Choosing, say, the most-probable next-token at each step is not guaranteed to yield the most probable *sequence* of tokens. So, **Beam Search** explores multiple sequences, using different branching strategies, and the probabilities of the various beam sequences can be compared at the end. Important note: generating the most-probable text is not necessarily equal to the most human-like text [20].

## 9 The Bare Minimum and the kv-Cache

There are two separate stages during generation. First, an original, to-be-continued series of prompts  $x_{bs}$  can be processed in parallel to both generate the first prediction and populate any intermediate values we may want to cache for later. We follow [21] and call this the **prefill** stage. For this procedure, we require the entire  $x_{bs}$  tensor.

In the second, iterative part of generation (the **decode** stage) we have now appended one-or-more tokens to the sequence and we again want the next prediction, i.e.  $z[:, -1, :]$  for the last-layer outputs  $z_{bsd}$ . In this stage, we can avoid re-processing the entire  $x_{bs}$  tensor and get away with only processing the final, newly added token, *if* we are clever and cache old results (and accept a very reasonable approximation).

The important pieces occur in the **CausalAttention** layer, as that's the only location in which the sequence index is not completely parallelized across operations. Referring back to Sec. 1.3, given the input  $z_{bsd}$  of the **CausalAttention** layer, the re-weighted value vectors<sup>44</sup>  $w_{bs's'd}^a v_{bs'f}^a$  are the key objects which determine the next-token-prediction, which only depends on the  $s = -1$  index values. Therefore, we can cut out many steps and minimum requirements are:

- Only the attention weights  $w_{bs's'd}^a$  with  $s = -1$  are needed
- The only query values  $q_{bsd}^a$  needed to get the above are those with  $s = -1$
- Every component of the key and value vectors  $k_{bsd}^a, v_{bsd}^a$  is needed, but because of the causal mask, all components except for the last in the sequence dimension ( $s \neq -1$ ) are the same as they were in the last iteration, up to a shift by one position<sup>45</sup>

So, we are led to the concept of the **kv-cache** in which we cache old key and query vectors for generation. The cache represents a tradeoff: fewer FLOPs are needed for inference, but the memory costs are potentially enormous, since the size of the cache grows with batch size and sequence length:

$$M_{\text{kv-cache}} = 2pBDLS/T , \quad (9.1)$$

in the general case with tensor-parallelism. This can easily be larger than the memory costs of the model parameter:  $M_{\text{params}}^{\text{inference}} \sim pN_{\text{params}} \sim pLD^2$  (dropping  $\mathcal{O}(1)$  factors), so that the cache takes up more memory when  $BS \gtrsim D$ . Also, re-processing an entire input sequence every time also exhibits the usual  $\sim S^2$  FLOPs dependence, but this is reduced to being *linear* in  $S$  when the kv-cache is used.

A very minimal implementation<sup>46</sup> is below:

```

6  class CausalAttentionWithCache(CausalAttention):
7      def __init__(self, *args, **kwargs):
8          super().__init__(*args, **kwargs)
9          self.cached_keys = self.cached_values = None
10
11     def forward(self, inputs, use_cache=True):
12         """Forward method with optional cache. When use_cache == True, the output will have a
13         sequence length of one."""

```

<sup>44</sup>Summed over  $s'$ , but concatenating the different  $a$  values over the  $f$  dimension.

<sup>45</sup>This is where we need to accept a mild approximation, if using a sliding attention window. With an infinite context window, if we add a label  $t$  which indexes the iteration of generation we are on, then we would have that  $z_{bsd}^{(t+1)} = z_{b(s-1)d}^{(t)}$  for every tensor in the network, except for when  $s = -1$ , the last position. The finiteness of the context window makes this statement slightly inaccurate because we can only ever keep  $K$  positions in context and the loss of the early tokens upon sliding the window over will slightly change the values in the residual stream.

<sup>46</sup>Warning: very non-optimized code! Purely pedagogical.

```

14     if not use_cache:
15         return super().forward(inputs)
16     if self.cached_keys is None:
17         # If the cache is not yet initialized, we need all q, k, v values.
18         assert (
19             self.cached_values is None
20         ), "If cached_keys is None, cached_values should be None, too"
21         queries, keys, values = self.get_qkv(inputs)
22     else:
23         # Otherwise, we only need q, k, v values for the last sequence position.
24         queries, new_keys, new_values = self.get_qkv(inputs[:, -1])
25         keys = [torch.cat([ck, nk], dim=1) for ck, nk in zip(self.cached_keys, new_keys)]
26         values = [torch.cat([cv, nv], dim=1) for cv, nv in zip(self.cached_values, new_values)]
27     # Update or initialize the cache.
28     self.cached_keys = [k[:, -self.block_size + 1 :] for k in keys]
29     self.cached_values = [v[:, -self.block_size + 1 :] for v in values]
30     last_queries = [q[:, -1] for q in queries]
31     attn_maps = self.get_attn_maps(last_queries, keys)
32     weighted_values = torch.cat(
33         [self.attn_dropout(a) @ v for a, v in zip(attn_maps, values)], dim=-1
34     )
35     z = self.O(weighted_values)
36     z = self.out_dropout(z)
37     return z

```

## 10 Basic Memory, FLOPs, Communication, and Latency

The essentials of inference-time math, much of it based on [22].

**Naive Inference** Processing a single ( $B$ ,  $S$ ,  $D$ )-shaped tensor to generate a single next input costs the  $2BSN_{\text{params}}$  FLOPs we found for the forwards-pass in Sec. 4. Memory costs just come from the parameters themselves:  $M_{\text{infer}}^{\text{naive}} = pN_{\text{params}}$ . Per the analysis of App. D, naive inference is generally compute-bound and so the per-token-latency is approximately<sup>47</sup>  $2BSN_{\text{params}}/\lambda_{\text{FLOP/s}}$  where the FLOPs bandwidth in the denominator is again defined in App. D.

**kv-Cache Inference** The FLOPs requirements for the hidden-dimension matrix multiplies during generation are  $2BN_{\text{params}}$ , since we are only processing a single token, per previous results. This is in addition to the up-front cost of  $2BSN_{\text{params}}$  for the prefill. But, the memory requirements are raised to

$$M_{\text{infer}}^{\text{kv-cache}} = pN_{\text{params}} + 2pBDLS/T . \quad (10.1)$$

Inference now has a computational-intensity of

$$\frac{C_{\text{infer}}^{\text{kv-cache}}}{M_{\text{infer}}^{\text{kv-cache}}} \sim \frac{BD}{S} , \quad (10.2)$$

dropping  $\mathcal{O}(1)$  factors, is now memory-bound (again, see App. D), and has per-token-latency of approximately  $M_{\text{infer}}/\lambda_{\text{mem}}$ , unless the batch-size is very large.

---

<sup>47</sup> Assuming we do the naive thing here and generate the next token in a similarly naive way, shifting over the context window.

**Intra-Node Communication** For  $T$ -way tensor parallelism, two `AllReduces` are needed, one for each MLP and each `CausalAttention` layer, where each accelerator is sending  $pBDS$  bytes of data (see Sec. 3.2). This requires a total of  $4(T - 1)pBDS/T \approx 4pBDS$  bytes to be transferred between workers in the tensor-parallel group (see Foot. 56), taking a total of  $\sim 4pBDLS/\lambda_{\text{comms}}$  time for the model as a whole. For an A100 80GiB, `torch.float16` setup, this is  $\sim BDS \times 10^{-11}$  sec

## Latency

### 11 Case Study: Falcon-40B

Let’s work through the details of the kv-cache for Falcon-40B<sup>48</sup> with  $D = 8192$ ,  $L = 60$ ,  $S = 2048$ . In half,  $p = 2$  precision, the model weights just about fit on an 80GiB A100, but this leaves no room for the cache, so we parallelize  $T$  ways across  $T$  GPUs, assumed to be on the same node. The total memory costs are then

$$M_{\text{total}} \approx \frac{80\text{GiB} + 4\text{GiB} \times B}{T}. \quad (11.1)$$

This means that in order to hit the compute-bound threshold of  $B \sim 200$  (see App. D) we need at least  $T = 4$  way parallelism. Taking  $T = 4$ , and running at capacity with  $B \sim 200$  so that we are compute-bound, the per-iteration latency from computation alone is approximately  $\frac{2BN_{\text{params}}}{\lambda_{\text{FLOP/s}} T} \sim 13\text{ms}$ , i.e. we can give  $\sim 200$  customers about  $\sim 75$  tokens-per-second at this rate<sup>49</sup>, if this were the only latency consideration.

## A Conventions and Notation

We loosely follow the conventions of [5]. Common parameters:

- $A$ : number of attention heads
- $B$ : microbatch size
- $C$ : compute (FLOPs)
- $D$ : the hidden dimension size
- $E$ : expansion factor for MLP layer (usually  $E = 4$ )
- $H$ :  $D/A$ , the head dimension size
- $K$ : the block size (maximum sequence length<sup>50</sup>)
- $L$ : number of transformer layers
- $N_{\text{params}}$ : total number of model parameters

---

<sup>48</sup>Falcon actually uses multi-query attention, which changes the computations here, but we will pretend it does not in this section for simplicity.

<sup>49</sup>Average human reading speed is about  $\sim 185$  words/minute, or  $\sim 4$  tokens/sec.

<sup>50</sup>In the absence of methods such as ALiBi [23] can be used to extend the sequence length at inference time.

- $P$ : pipeline parallel size
- $S$ : input sequence length
- $T$ : tensor parallel size
- $V$ : vocabulary size
- $t$ : various timescales
- $p$ : the precision of the elements of a tensor in bytes
- $\lambda$ : various rates, e.g.  $\lambda_{\text{mem}}$  is memory bandwidth

Where it makes sense, we try to use the lower-case versions of these characters to denote the corresponding indices on various tensors. For instance, an input tensor with the above batch size, sequence length, and vocabulary size would be written as  $x_{bsv}$ , with  $b \in \{0, \dots, B - 1\}$ ,  $s \in \{0, \dots, S - 1\}$ , and  $v \in \{0, \dots, V - 1\}$  in math notation, or as `x[b, s, v]` in code. Typical transformers belong to the regime

$$V \gg D, S \gg L, A \gg P, T . \quad (\text{A.1})$$

For instance, GPT-2 and GPT-3 [2, 3] have  $V \sim \mathcal{O}(10^4)$ ,  $S, L \sim \mathcal{O}(10^3)$ ,  $L, A \sim \mathcal{O}(10^2)$ . We will often assume also assume that<sup>51</sup>  $S \lesssim D$  or the weaker<sup>52</sup>  $BS \lesssim D$ .

As indicated above, we use zero-indexing. We also use `python` code throughout<sup>53</sup> and write all ML code using standard `torch` syntax. To avoid needing to come up with new symbols in math expressions we will often use expressions like  $x \leftarrow f(x)$  to refer to performing a computation on some argument ( $x$ ) and assigning the result right back to the variable  $x$  again.

Physicists often joke (half-seriously) that Einstein's greatest contribution to physics was his summation notation in which index-sums are implied by the presence of repeated indices and summation symbols are entirely omitted. For instance, the dot product between two vectors would be written as

$$\vec{x} \cdot \vec{y} = \sum_i x_i y_i \equiv x_i y_i \quad (\text{A.2})$$

We use similar notation which is further adapted to the common element-wise deep-learning operations. The general rule is that if a repeated index appears on one side of an equation, but not the other, then a sum is implied, but if the same index appears on both sides, then it's an element-wise operation. The Hadamard-product between two matrices  $A$  and  $B$  is just

$$C_{ij} = A_{ij} B_{ij} . \quad (\text{A.3})$$

---

<sup>51</sup>This condition ensures that the  $\mathcal{O}(S^2)$  FLOPs cost from self-attention is negligible compared to  $\mathcal{O}(D^2)$  contributions from other matrix multiplies. It should be noted that in Summer 2023 we are steadily pushing into the regime where this condition does *not* hold.

<sup>52</sup>This condition ensures that the cost of reading the  $\mathcal{O}(D^2)$  weights is more than the cost of reading in the  $\mathcal{O}(BSD)$  entries of the intermediate representations.

<sup>53</sup>Written in a style conducive to latex, e.g. no type-hints and clarity prioritized over optimization.

Einstein notation also has implementations available for `torch`: see this blog post on `einsum` or the `einops` package.

In particular, we use `einops` notation for concatenation and splitting:  $A_e = A_{(de)} = B_{de}$ <sup>54</sup>. We will sometimes use a bar to indicate tensors which are derived from other tensors through such splitting operations, usually in the context of tensor-sharding where devices only locally hold some shard of the tensor. In this context, only some of the dimensions will be sharded across devices, and we may also put a bar over the corresponding sharded index. For instance, consider a two-dimensional tensor  $M_{ab}$  of shape `M.shape=(A, 2)`: sharding this tensor across two devices across the final index results in a tensor  $\bar{M}_{\bar{a}\bar{b}}$  which is of shape `M_bar.shape=(A, 1)` on each device. As here, we will sometimes use bars to denote indices which are sharded over different devices.

We also put explicit indices on operators such as Softmax to help clarify the relevant dimension, e.g. we would write the softmax operation over the  $b$ -index of some batched tensor  $x_{bvd\dots}$  as

$$s_{bvd\dots} = \frac{e^{x_{bvd\dots}}}{\sum_{v=0}^{V-1} e^{x_{bvd\dots}}} \equiv \text{Softmax}_v x_{bvd\dots}, \quad (\text{A.4})$$

indicating that the sum over the singled-out  $v$ -index is gives unity.

## B Collective Communications

A quick refresher on common distributed communication primitives. Consider  $N$  workers with tensor data  $x^{(n)}$  of some arbitrary shape `x.shape`, which takes up  $M$  bytes of memory, where  $n$  labels the worker and any indices on the data are suppressed. The  $n = 0$  worker is arbitrarily denoted the *chief*. Then, the primitive operations are:

- **Broadcast**: all workers receive the chief's data,  $x^{(0)}$ .
- **Gather**: all workers communicate their data  $x_n$  to the chief, e.g. in a concatenated array  $[x^0, x^1, \dots, x^{N-1}]$ .
- **Reduce**: data is **Gather**-ed to the chief, which then performs some operation (`sum`, `max`, `concatenate`, etc.) producing a new tensor  $x'$  on the chief worker.
- **ReduceScatter**: a reducing operation (e.g. `sum`) is applied to the  $x^{(n)}$  to produce a  $x'$  of the same shape (e.g.  $x' = \sum x^{(n)}$ ) and each worker only receives a  $1/N$  slice (and hence  $M/N$  byte) of the result<sup>55</sup>. A ring implementation sends  $M \times \frac{N-1}{N}$  bytes over each link in the ring.
- **AllGather**: all data  $x^{(n)}$  is communicated to all workers; each worker ends up with the array  $[x^0, x^1, \dots, x^{N-1}]$ . Functionally equivalent to a **Gather** followed by **Broadcast**. A ring implementation sends  $M \times (N - 1)$  bytes over each link in the ring.

---

<sup>54</sup>The indexing is all row-major: if  $A_i$  is  $I$ -dimensional,  $i \in \{0, \dots, I - 1\}$ , then if we split this index as  $A_i = A_{(jk)} \equiv \bar{A}_{jk}$ , then the indices  $j, k$  will range over  $j \in \{0, \dots, J\}$ ,  $k \in \{0, \dots, K\}$  with  $I = J \times K$  and where numerically  $i = j \times K + k$ . More complex cases follow by induction.

<sup>55</sup>Note that **AllGather** and **ReduceScatter** are morally conjugate to each other. In the former, each worker ends up with  $N$  times as much data as they started with, while in **ReduceScatter** they end up with  $1/N$  of their initial data. One is nearly a time-reversed version of the other, which is a way of remembering that they have the same communication cost. They also compose to produce an output of the same initial size, as in **AllReduce**.

- **AllReduce**: all workers receive the same tensor  $x'$  produced by operating on the  $x^{(n)}$  with **sum**, **mean**, etc. Functionally equivalent to a **Reduce** followed by **Broadcast**, or a **ReduceScatter** followed by a **AllGather** (the more efficient choice<sup>56</sup>). In the latter case, the total cost is  $2M \times \frac{N-1}{N}$ , due to **AllReduce**-ing the initial  $M$ -sized data, and then **AllGather**-ing the  $M/N$ -sized reductions.

## C Hardware

Basic information about relevant hardware considerations. Much of the following is from the [NVIDIA docs](#).

### C.1 NVIDIA GPU Architecture

NVIDIA GPUs consist of some amount of relatively-slow off-chip DRAM memory<sup>57</sup>, relatively-fast on-chip SRAM, and a number of **streaming multiprocessors** (SMs) which perform the parallel computations. Inside more-recent GPUs, the SMs carry both “CUDA cores” and “Tensor cores”, where the latter are used for matrix-multiplications and the former for everything else.

A few numbers of primary importance:

- The rate at which data can be transferred from DRAM to SRAM ( $\lambda_{\text{mem}}$ )
- The number of FLOP/s, which is more fundamentally computed by multiplying the number of SMs by the FLOPS/cycle of each SM for the specific operation under consideration (see the NVIDIA docs) by the clock rate:  $N_{\text{SM}} \cdot \lambda_{\text{FLOPs/cycle}} \cdot \lambda_{\text{clock}}$

The terminology and structure of the memory hierarchy is also important to understand. Types of memory, from slowest to fastest:

- **Global** memory is the slow, but plentiful, off-chip DRAM. It is the type of memory typically used as kernel arguments
- **Constant** memory is read only and accessible by all threads in a given block. The size of arrays in constant memory must be known at compile time
- **Local Memory** is similarly slow to global memory, but more plentiful than register memory, and privately to individual threads and is allocated from within a kernel. When registers run out, local memory fills the gap
- **Shared** memory is shared between all threads in a given block. Shared memory is effectively a user-controlled cache. The size of arrays in shared memory must be known at compile time

---

<sup>56</sup>The former strategy scales linearly with the number of worker, while the latter strategy underlies “ring” **AllReduce** which is (nearly) independent of the number of workers: if each worker carries data of size  $D$  which is to be **AllReduced**, a total of  $\frac{2(N-1)D}{N}$  elements need to be passed around. See [this blog post](#) for a nice visualization or [24] for a relevant paper.

<sup>57</sup>This is the number usually reported when discussing a given GPU, e.g. 32GiB for the top-of-the-line A100

- **Registers** hold scalar values and small tensors whose values are known at compile time. They are local to each thread and they are plentiful since each thread needs its own set of registers:  $65,536 = 2^{16}$  registers per SM an A100.

An excellent video overview of CUDA and NVIDIA GPU architecture which covers some of the above is here.

## C.2 CUDA Programming Model

The CUDA programming model uses a hierarchy of concepts:

- **Threads** are the fundamental unit of execution<sup>58</sup> which each run the same CUDA **Kernel**, or function, on different data inputs in parallel. Threads within the same block (below) may share resources, like memory, and may communicate with each other. Individual threads are indexed through the **threadIdx** variable, which has `threadIdx.{x, y, z}` attributes with `threadIdx.x` in  $0, \dots, \text{blockDim.x} - 1$  and similar.
- Threads (and hence warps) are organized into 3D **blocks**. The size and indices of the blocks can be accessed through the `blockDim` and `blockIdx` variables, respectively, with `blockIdx.x` in  $0, \dots, \text{gridDim.x} - 1$ . `blockDim.x * blockDim.y * blockDim.z` total threads run in a block.
- Blocks are organized into 3D **groups**. The size of the grid dimensions can be accessed through the `gridDim` variable, with similar attributes to the above.  
`gridDim.x * gridDim.y * gridDim.z` total blocks run in a grid.

The number of threads which can be launched in a given block is hardware limited; A100 80GiB GPUs can run up to 1024 threads in a SM at a time (32 blocks with 32 threads each), for instance. Hence, block and grid sizes need to be adjusted to match the problem size. There are also important memory access considerations here. The 1024 threads which can be launched can also read sequentially from memory and efficient usage implies that choosing the block size such that we are doing these reads as often as possible is ideal.

## C.3 NVIDIA GPU Stats

Summary of some relevant NVIDIA GPU statistics:

| GPU  | Memory | $\lambda_{\text{FLOP/s}}$ | $\lambda_{\text{mem}}$ | $\lambda_{\text{math}}$ | $\lambda_{\text{comms}}$ |
|------|--------|---------------------------|------------------------|-------------------------|--------------------------|
| A100 | 80GiB  | 312 TFLOP/s               | 2.0 TiB/s              | 156 FLOPS/B             | 300 GiB/s                |
| A100 | 40GiB  | 312 TFLOP/s               | 1.6 TiB/s              | 195 FLOPS/B             | 300 GiB/s                |
| V100 | 32GiB  | 130 TFLOP/s               | 1.1 TiB/s              | 118 FLOPS/B             | 16 GiB/s                 |

where

- $\lambda_{\text{FLOP/s}}$  is flops bandwidth (for `float16/bfloat16` multiply-accumulate ops)
- $\lambda_{\text{mem}}$  is memory bandwidth

---

<sup>58</sup>Threads are always physically launched in **Warps** which consist of 32 threads.

- $\lambda_{\text{math}} = \frac{\lambda_{\text{FLOP/s}}}{\lambda_{\text{mem}}}$  is **math bandwidth**
- $\lambda_{\text{comms}}$  is one-way communication bandwidth

A useful approximate conversion rate is that  $1 \text{ TFLOP/s} \approx 100 \text{ PFLOP/day}$ .

Important practical note: the  $\lambda_{\text{FLOP/s}}$  numbers should be taken as aspirational. Out-of-the box, `torch.float16` matrix-multiples in `torch` with well-chosen dimensions tops out around  $\sim 250 \text{ FLOPS/s}$

## D Compute-bound vs Memory-bound

If your matrix-multiples are not sufficiently large on, you are wasting resources [25]. The relevant parameters which determine sufficiency are  $\lambda_{\text{FLOP/s}}$  and  $\lambda_{\text{mem}}$ , the FLOPs and memory bandwidth, respectively. The ratio  $\lambda_{\text{math}} \equiv \frac{\lambda_{\text{FLOP/s}}}{\lambda_{\text{mem}}}$  determines how many FLOPS you must perform for each byte loaded from memory; see App. C.3. If your computations have a FLOPs/B ratio which is larger than  $\lambda_{\text{math}}$ , then you are compute-bound (which is good, as you're maximizing compute), and otherwise you are memory(-bandwidth)-bound (which is bad, since your compute capabilities are idling). The FLOPs/B ratio of your computation is sometimes called the **compute intensity** or **arithmetic intensity**. When compute bound, a process takes time  $\sim F/\lambda_{\text{FLOP/s}}$ , while memory-bound processes take time<sup>59</sup>  $\sim M/\lambda_{\text{mem}}$ .

### D.1 Matrix-Multiplications vs. Element-wise Operations

For instance, to multiply a  $(B, S, D)$ -shaped tensor  $z_{bsd}$  by a  $(D, D)$ -shaped weight-matrix  $W_{dd'}$ ,  $p(BDS + D^2)$  bytes must be transferred from DRAM to SRAM at a rate  $\lambda_{\text{mem}}$ , after which we perform  $2BSD^2$  FLOPs, and write the  $(B, S, D)$ -shaped result back to DRAM again, for a ratio of

$$\frac{1}{p} \frac{BDS}{2BS + D} \text{ (FLOPs/B)} . \quad (\text{D.1})$$

We want to compare this against  $\lambda_{\text{math}}$ , which from App. C.3 we take to be  $\mathcal{O}(100 \text{ FLOPs/B})$ , and plugging in any realistic numbers, shows that such matrix-multiples are essentially always compute-bound. Compare this to the case of some element-wise operation applied to the same  $z_{bsd}$  tensor whose FLOPs requirements are  $\sim C \times BDS$  for some constant-factor  $C \ll S, D$ . Then, then FLOPS-to-bytes ratio is  $\sim \frac{C}{p}$ , which is *always* memory-bound for realistic values of  $C$ . The moral is to try and maximize the number of matrix-multiples and remove as many element-wise operations that you can get away with.

### D.2 Training vs. Inference

Finally, we note that the above has implications for the Transformers architecture as a whole, and in particular it highlights the difficulties in efficient inference. Under the assumptions of Sec. 4,  $\sim \mathcal{O}(BSN_{\text{params}})$  total FLOPs needed during training, while the number of bytes loaded from and

---

<sup>59</sup>Note that the time is not additive, e.g. compute-bound tasks do not take time  $\sim F/\lambda_{\text{FLOP/s}} + M/\lambda_{\text{mem}}$  because they are not sequential: compute and memory-communications can be concurrent.

written to memory are  $\mathcal{O}(BDLS + N_{\text{params}}) \sim \mathcal{O}\left(\frac{BSN_{\text{params}}}{D} + N_{\text{params}}\right)$  which is  $\mathcal{O}(N_{\text{params}})$  for not-super-long sequence lengths. The arithmetic intensity is therefore  $\mathcal{O}(BS)$  and so training is compute-bound in any usual scenario, even at small  $B \sim \mathcal{O}(1)$  batch sizes (as long as individual operations in the network don't suffer from outlandish memory-boundedness). The problem during inference is that (if using the kv-cache; see Sec. 9) we only need to process a *single* token at a time and so  $S \rightarrow 1$  in the numerator in the preceding, while the denominator is also weighed down by the kv-cache in the attention layers.

In more detail, the MLP layers just process  $S = 1$  length tensors during generation, but are insensitive to the kv-cache, so their intensity comes from just setting  $S = 1$  in the above,

$$\sim \frac{BD}{B+D} , \quad (\text{D.2})$$

dropping  $\mathcal{O}(1)$  factors now, while the attention layers have a ratio of the form

$$\sim \frac{BDS + BD^2}{BD + D^2 + BDS} , \quad (\text{D.3})$$

where the last term in the denominator is due to the cache. Now at small  $B \sim \mathcal{O}(1)$  batch sizes, both intensities reduce to  $\mathcal{O}(B)$ , which is insufficient to be compute-bound. In the large  $B \gtrsim D/S$  limit, they at least become  $\mathcal{O}(D)$  and  $\mathcal{O}(1 + \frac{D}{S})$ , respectively, which may be enough to be compute-bound, but it's hard to even get into this regime. Note, the importance of the ratio  $D/S$ . The hidden dimension fixes the context length scale at which inference can never be compute-bound, in the absence of additional tricks not considered here<sup>60</sup>.

### D.3 Intra- and Inter-Node Communication

For intra-node communication, GPUs are connected by either PCIe or NVLink, generally.

- NVLink interconnects are continually updated and achieve speeds of  $\lambda_{\text{comm}}^{\text{intra}} \sim 300 \text{ GiB/s}$ .

For inter-node communication, nodes are often connected by:

- InfiniBand apparently also achieves speeds  $\lambda_{\text{comm}}^{\text{intra}} \sim 100 \text{ GiB/s}$ ? Haven't found a clear reference. But in any case, the bandwidth is divided amongst the GPUs in the node, leading to a reduction by  $\sim 8$ .

## E Batch Size, Compute, and Training Time

The amount of compute directly determines the training time, but not all ways of spending compute are equivalent. We follow the discussion in [26] which gives a rule of thumb for determining the optimal batch size which is sometimes used in practice. The basic point is that all of the optimization steps take the gradient  $\mathbf{g}$  as an input, and since the gradient is the average over

---

<sup>60</sup>One such trick: the multi-query attention of Sec. 2.1 improves everything a factor of  $A$ : the large batch regime is  $B \gtrsim \frac{D}{AS}$  and the intensity ratio becomes  $\mathcal{O}(1 + \frac{D}{AS})$ . An analysis equivalent to the one performed here can be found in the original paper [7].

randomly selected datapoints, steps are more precise as the batch size increases (with diminishing returns, past a certain point, but the computational cost also rises with batch size, and a balance between the two concerns should be struck).

Consider vanilla SGD and study how the training loss changes with each step. We randomly sample  $B$  datapoints  $x \in \mathcal{D}$  from the dataset through some i.i.d. process<sup>61</sup>. Each corresponding gradient  $\mathbf{g}(x) = \partial_w \mathcal{L}(w, x)$  is itself a random variable whose average is the true gradient across the entire dataset  $\bar{\mathbf{g}}$  and we take the variance to be

$$\text{Var}[\mathbf{g}(x), \mathbf{g}(x')] = \Sigma \quad (\text{E.1})$$

for some matrix  $\Sigma$  with (suppressed) indices spanning the space of model weights. Taking instead the mean of a sum of such estimates,  $\mathbf{g}_B \equiv \frac{1}{B} \sum_{x \in \mathcal{B}} \mathbf{g}(x)$ , the mean stays the same, but the variance reduces in the usual way:  $\text{Var}[\mathbf{g}_B(x), \mathbf{g}_B(x')] = \Sigma/B$ .

Study the mean loss across the entire dataset:  $\mathcal{L}(w) = \langle \mathcal{L}(w, x) \rangle$ . Using SGD we take a step  $w \rightarrow w - \eta \mathbf{g}_B$  and change the loss as

$$\mathcal{L}(w - \eta \mathbf{g}_B) = \mathcal{L}(w) - \eta \bar{\mathbf{g}} \cdot \mathbf{g}_B + \frac{1}{2} \mathbf{g}_B \cdot H \cdot \mathbf{g}_B + \mathcal{O}(\mathbf{g}_B^3), \quad (\text{E.2})$$

where  $H$  is the true hessian of the loss over the entire dataset at this value of the weights. Taking the expectation value and minimizing the results over  $\eta$  gives the optimal choice:

$$\eta_\star = \frac{\eta_{\max}}{1 + \frac{B_{\text{noise}}}{B}}, \quad \eta_{\max} \equiv \frac{\bar{\mathbf{g}}^2}{\bar{\mathbf{g}} \cdot H \cdot \bar{\mathbf{g}}}, \quad B_{\text{noise}} \equiv \frac{\text{Tr } H \cdot \Sigma}{\bar{\mathbf{g}} \cdot H \cdot \bar{\mathbf{g}}}. \quad (\text{E.3})$$

Notably, the above supports the usual rule of thumb that the learning rate should be increased proportionally to the batch size, at least whenever  $B \ll B_{\text{noise}}$ . The diminishing returns of pushing batch sizes past  $B_{\text{noise}}$  are also evident. In practice it is too expensive to compute the Hessian, but thankfully the entirely unjustified approximation in which the Hessian is multiple of the identity such that

$$B_{\text{noise}} \approx B_{\text{simple}} \equiv \frac{\text{Tr } \Sigma}{\bar{\mathbf{g}}^2}, \quad (\text{E.4})$$

is somehow a decent approximation empirically, and an estimator can be created for  $B_{\text{noise}}$  in a data-parallel setup; see [26] or [Katherine Crowson's implementation](#) or [neox](#) for more.

We can further characterize the trade-off between compute and optimization steps. The expected decrease in loss per update is then

$$\langle \delta \mathcal{L} \rangle \approx \frac{\eta_{\max}}{1 + \frac{B_{\text{noise}}}{B}} \bar{\mathbf{g}}^2 + \mathcal{O}(\eta_{\max}^2), \quad (\text{E.5})$$

that is, we would need  $1 + \frac{B_{\text{noise}}}{B}$  times as many SGD steps to make the same progress we would have as compared to full-batch SGD. If  $S_{\min}$  is the number of steps that would have been needed for full-batch SGD, we would need  $S = S_{\min} + S_{\min} \frac{B_{\text{noise}}}{B}$  steps for minibatch SGD. The total number of examples seen is correspondingly  $E = S_{\min} \times (B_{\text{noise}} + B) \equiv E_{\min} + S_{\min} B$ , and so we see the

---

<sup>61</sup>The below uses sampling with replacement, while in practice we sample without replacement, but the different is negligible for all practical cases.

trade-off between SGD steps  $S$  and compute  $E$  alluded to above. These relations can be written as<sup>62</sup>

$$\left(\frac{S}{S_{\min}} - 1\right) \left(\frac{E}{E_{\min}} - 1\right) = 1 \quad (\text{E.6})$$

which represent hyperbolic Pareto frontier curves. So, solutions are of the form  $S = (\alpha + 1) S_{\min}$ ,  $E = (\frac{1}{\alpha} + 1) E_{\min}$  and since  $E = BS$  the corresponding batch size is  $B_{\text{crit}} \equiv \frac{1}{\alpha} B_{\text{noise}}$ . The parameter  $\alpha$  characterizes how much you value the trade-off between these two factors and a reasonable balance is the  $\alpha = 1$  solution for which  $S = 2S_{\min}$ ,  $E = 2E_{\min}$  and  $B_{\text{crit}} = B_{\text{noise}}$  exactly.

Correspondingly, in [26] they suggest training at precisely this batch size. But it seems much more relevant to balance time against compute directly, rather than optimization steps vs compute. Modeling the total training time by  $T \approx S(\kappa B + \sigma)$  for some  $\kappa, \sigma$  to model compute costs<sup>63</sup>, then the above is equivalent to

$$T = \frac{(E_{\min} + S_{\min}B)(\kappa B + \sigma)}{B} . \quad (\text{E.7})$$

which has a minimum at

$$B = \sqrt{\frac{\sigma E_{\min}}{\kappa S_{\min}}} . \quad (\text{E.8})$$

for which the total time is

$$T_{\min} = \left( \sqrt{\kappa E_{\min}} - \sqrt{\sigma S_{\min}} \right)^2 . \quad (\text{E.9})$$

In comparison, the total time for the  $B_{\text{crit}} = \frac{E_{\min}}{S_{\min}}$  strategy of [26] gives  $T_{\min} = 2(\kappa E_{\min} + \sigma S_{\min})$  which is a factor of  $\frac{2}{1 - \frac{\sqrt{\sigma \kappa} B_{\text{noise}}}{\kappa B_{\text{noise}} + \sigma}}$  larger. So, this seems like a better choice of optimal batch size, if you value your time.

## F Cheat Sheet

Collecting all of the most fundamental equations, given to various degrees of accuracy.

Number of model parameters:

$$N_{\text{params}} = (4 + 2E)L D^2 + VD + \mathcal{O}(DL) \approx (4 + 2E)L D^2 , \quad (\text{F.1})$$

assuming no sharding of the embedding matrix.

---

<sup>62</sup>The analysis here is simplified in that it assumes that the noise scale and the chosen batch size are both time-independent. There is confusing logic treating the more general case where both  $B_{\text{noise}}$  and  $B$  vary with step in [26], but in any case, the ultimate relations they use are effectively the same.

<sup>63</sup>Computation and communication costs each scale with  $B$ , the optimizer step does not (and maybe some overhead?), for instance.

**Training** Memory costs for mixed-precision training:

$$\begin{aligned} M_{\text{model}} &= p_{\text{model}} N_{\text{params}} \\ M_{\text{optim}} &= (s_{\text{states}} + 1) \times p_{\text{master}} N_{\text{params}} \\ M_{\text{act}}^{\text{total}} &= \frac{2BDLS(p(E+4)+1)}{T} + \frac{ABLS^2(2p+1)}{T} + \mathcal{O}(BSV) \end{aligned} \quad (\text{F.2})$$

where  $s_{\text{states}}$  is the number of optimizer states, e.g.  $s = 0$  for SGD and  $s = 2$  for Adam. FLOPs total:

$$F_{\text{total}}^{\text{model}} \approx 12BDLS(S + (2+E)D). \quad (\text{F.3})$$

## G TODO

- Tokenizers
- Activations
- Flash attention
- BERT family
- Residual stream
- Scaling laws
- Cheat sheet

## References

- [1] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” [arXiv:1706.03762 \[cs.CL\]](#). 3, 4
- [2] A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, I. Sutskever, *et al.*, “Language models are unsupervised multitask learners,” *OpenAI blog* **1** (2019) no. 8, 9. 3, 37
- [3] T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. M. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei, “Language models are few-shot learners,” [arXiv:2005.14165 \[cs.CL\]](#). 37
- [4] OpenAI, “Gpt-4 technical report,” [arXiv:2303.08774 \[cs.CL\]](#). 3
- [5] V. Korthikanti, J. Casper, S. Lym, L. McAfee, M. Andersch, M. Shoeybi, and B. Catanzaro, “Reducing activation recomputation in large transformer models,” [arXiv:2205.05198 \[cs.LG\]](#). 3, 17, 24, 27, 36
- [6] R. Xiong, Y. Yang, D. He, K. Zheng, S. Zheng, C. Xing, H. Zhang, Y. Lan, L. Wang, and T.-Y. Liu, “On layer normalization in the transformer architecture,” [arXiv:2002.04745 \[cs.LG\]](#). 4
- [7] N. Shazeer, “Fast transformer decoding: One write-head is all you need,” [arXiv:1911.02150 \[cs.NE\]](#). 5, 11, 42

- [8] G. Yang, E. J. Hu, I. Babuschkin, S. Sidor, X. Liu, D. Farhi, N. Ryder, J. Pachocki, W. Chen, and J. Gao, “Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer,” [arXiv:2203.03466 \[cs.LG\]](#). 6
- [9] J. Ainslie, J. Lee-Thorp, M. de Jong, Y. Zemlyanskiy, F. Lebrón, and S. Sanghai, “Gqa: Training generalized multi-query transformer models from multi-head checkpoints,” [arXiv:2305.13245 \[cs.CL\]](#). 12
- [10] H. Touvron, L. Martin, K. Stone, P. Albert, A. Almahairi, Y. Babaei, N. Bashlykov, S. Batra, P. Bhargava, S. Bhosale, D. Bikel, L. Blecher, C. C. Ferrer, M. Chen, G. Cucurull, D. Esiobu, J. Fernandes, J. Fu, W. Fu, B. Fuller, C. Gao, V. Goswami, N. Goyal, A. Hartshorn, S. Hosseini, R. Hou, H. Inan, M. Kardas, V. Kerkez, M. Khabsa, I. Kloumann, A. Korenev, P. S. Koura, M.-A. Lachaux, T. Lavril, J. Lee, D. Liskovich, Y. Lu, Y. Mao, X. Martinet, T. Mihaylov, P. Mishra, I. Molybog, Y. Nie, A. Poulton, J. Reizenstein, R. Rungta, K. Saladi, A. Schelten, R. Silva, E. M. Smith, R. Subramanian, X. E. Tan, B. Tang, R. Taylor, A. Williams, J. X. Kuan, P. Xu, Z. Yan, I. Zarov, Y. Zhang, A. Fan, M. Kambadur, S. Narang, A. Rodriguez, R. Stojnic, S. Edunov, and T. Scialom, “Llama 2: Open foundation and fine-tuned chat models,” [arXiv:2307.09288 \[cs.CL\]](#). 12
- [11] A. Chowdhery, S. Narang, J. Devlin, M. Bosma, G. Mishra, A. Roberts, P. Barham, H. W. Chung, C. Sutton, S. Gehrmann, P. Schuh, K. Shi, S. Tsvyashchenko, J. Maynez, A. Rao, P. Barnes, Y. Tay, N. Shazeer, V. Prabhakaran, E. Reif, N. Du, B. Hutchinson, R. Pope, J. Bradbury, J. Austin, M. Isard, G. Gur-Ari, P. Yin, T. Duke, A. Levskaya, S. Ghemawat, S. Dev, H. Michalewski, X. Garcia, V. Misra, K. Robinson, L. Fedus, D. Zhou, D. Ippolito, D. Luan, H. Lim, B. Zoph, A. Spiridonov, R. Sepassi, D. Dohan, S. Agrawal, M. Omernick, A. M. Dai, T. S. Pillai, M. Pellat, A. Lewkowycz, E. Moreira, R. Child, O. Polozov, K. Lee, Z. Zhou, X. Wang, B. Saeta, M. Diaz, O. Firat, M. Catasta, J. Wei, K. Meier-Hellstern, D. Eck, J. Dean, S. Petrov, and N. Fiedel, “Palm: Scaling language modeling with pathways,” [arXiv:2204.02311 \[cs.CL\]](#). 12
- [12] J. Su, Y. Lu, S. Pan, A. Murtadha, B. Wen, and Y. Liu, “Roformer: Enhanced transformer with rotary position embedding,” [arXiv:2104.09864 \[cs.CL\]](#). 12
- [13] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Ré, “Flashattention: Fast and memory-efficient exact attention with io-awareness,” [arXiv:2205.14135 \[cs.LG\]](#). 13, 15
- [14] T. Dao, “Flashattention-2: Faster attention with better parallelism and work partitioning,” [arXiv:2307.08691 \[cs.LG\]](#). 13
- [15] P. Micikevicius, S. Narang, J. Alben, G. Diamos, E. Elsen, D. Garcia, B. Ginsburg, M. Houston, O. Kuchaiev, G. Venkatesh, and H. Wu, “Mixed precision training,” [arXiv:1710.03740 \[cs.AI\]](#). 16
- [16] D. Hendrycks and K. Gimpel, “Gaussian error linear units (gelus),” [arXiv:1606.08415 \[cs.LG\]](#). 19
- [17] M. Shoeybi, M. Patwary, R. Puri, P. LeGresley, J. Casper, and B. Catanzaro, “Megatron-lm: Training multi-billion parameter language models using model parallelism,” [arXiv:1909.08053 \[cs.CL\]](#). 19, 21, 22, 25
- [18] J. Kaplan, S. McCandlish, T. Henighan, T. B. Brown, B. Chess, R. Child, S. Gray, A. Radford, J. Wu, and D. Amodei, “Scaling laws for neural language models,” [arXiv:2001.08361 \[cs.LG\]](#). 29, 30, 31
- [19] J. Hoffmann, S. Borgeaud, A. Mensch, E. Buchatskaya, T. Cai, E. Rutherford, D. de Las Casas, L. A. Hendricks, J. Welbl, A. Clark, T. Hennigan, E. Noland, K. Millican, G. van den Driessche, B. Damoc, A. Guy, S. Osindero, K. Simonyan, E. Elsen, J. W. Rae, O. Vinyals, and L. Sifre, “Training compute-optimal large language models,” [arXiv:2203.15556 \[cs.CL\]](#). 30, 31
- [20] A. Holtzman, J. Buys, L. Du, M. Forbes, and Y. Choi, “The curious case of neural text degeneration,” [arXiv:1904.09751 \[cs.CL\]](#). 33

- [21] R. Pope, S. Douglas, A. Chowdhery, J. Devlin, J. Bradbury, A. Levskaya, J. Heek, K. Xiao, S. Agrawal, and J. Dean, “Efficiently scaling transformer inference,” [arXiv:2211.05102 \[cs.LG\]](https://arxiv.org/abs/2211.05102). 33
- [22] C. Chen, “Transformer inference arithmetic.”.  
<https://kipp.ly/blog/transformer-inference-arithmetic/>. 35
- [23] O. Press, N. A. Smith, and M. Lewis, “Train short, test long: Attention with linear biases enables input length extrapolation,” *CoRR* **abs/2108.12409** (2021) , 2108.12409.  
<https://arxiv.org/abs/2108.12409>. 36
- [24] P. Patarasuk and X. Yuan, “Bandwidth optimal all-reduce algorithms for clusters of workstations,” *Journal of Parallel and Distributed Computing* (2009) . 39
- [25] H. He, “Making deep learning go brrrr from first principles.”. [https://horace.io/brrr\\_intro.html](https://horace.io/brrr_intro.html). 41
- [26] S. McCandlish, J. Kaplan, D. Amodei, and O. D. Team, “An empirical model of large-batch training,” [arXiv:1812.06162 \[cs.LG\]](https://arxiv.org/abs/1812.06162). 42, 43, 44