tldr; techniques to speed up training and inference of LLMs to use large context window up to 100K input tokens during training and inference: ALiBi positional embedding, Sparse Attention, FlashAttention, Multi-Query attention, Conditional computation, and 80GB A100 GPUs.

Recently there were several announcements about new Large Language Models (LLMs) that can consume an extremely large context window, such as 65K tokens (MPT-7B-StoryWriter-65k+ by MosaicML) or even 100K tokens (Introducing 100K Context Windows by Anthropic). In the PaLM 2 technical report, Google doesn’t reveal the context size but mentions that they “increase the context length of the model significantly.”

For comparison, the current GPT-4 model can work with the context length of 32K input tokens. And most of the open-source LLMs have a context length of 2K tokens.

That’s impressive since having such a large context length means the prompt can literally be the size of a book. The Great Gatsby is 72K tokens, 210 pages, and 6 hours of reading at a 1.7 min/page speed. So the model can scan and keep this amount of “custom” information to process queries!

I was trying to wrap my head around how that is technically possible, so in this blog post, I collect scattered pieces of information (this thread was the first clue) and cover the following:

  • Why context length matters and why it can be a game changer
  • What are the main limitations in the original Transformer architecture when working with large context lengths
  • The computational complexity of the transformer architecture
  • What optimization techniques currently exist to speed up the transformer and increase the context length up to 100K

Here and later, we use the “context length,” “context window,” and “the number of input tokens” interchangeably, denoting them as n.

The blog post is a bit long, so there is a summary with the main points and tricks:

  • 1st problem is the quadratic time and space complexity of attention layer computations w.r.t. the number of input tokens n.
  • When the embedding size d > n, the 2nd problem is the quadratic time complexity of linear layers w.r.t. embedding size d.
  • 3rd problem is Positional Sinusoidal Embedding used in the original architecture.
  • In Transformer architecture, the shapes of learnable matrix weights are agnostic to the number of input tokens n.
  • So, a Transformer trained with a 2K context length can consume sequences of any length, even 100K. But the model will not produce meaningful results on 100K tokens during inference if it isn’t trained on 100K.
  • Training the vanilla Transformer on a giant corpus and only on a large context length is infeasibly expensive due to the quadratic complexity w.r.t. n and d. LLaMA on 2K context length was estimated to be trained for ~$3M. Thus, LLaMA on 100K would cost ~$150M.
  • One option is to train the model on 2K tokens context and then fine-tune it in longer contexts (for example, 65K). But it won’t work with the original Transformer because of the Positional Sinusoidal Encoding.
  • [Trick #1] To address this, remove Positional Sinusoidal Encoding and use ALiBi, a simple and elegant positional embedding that doesn’t hurt accuracy. Then you can train on 2K and fine-tune on 100K.
  • [Trick #2] You don’t need to calculate attention scores between all tokens. Some tokens are more important than others, so Sparse Attention can be used. It will speed up both training and inference.
  • [Trick #3] FlashAttention efficiently implements the attention layer for GPU. It uses tiling and avoids materialization of big intermediate matrices (n, n) that don’t fit into GPU SRAM. It will speed up both training and inference.
  • [Trick #4] Multi-Query attention instead of Multi-Head attention. That means you share weights across all heads when linearly projecting K and V. It dramatically speeds up incremental inference.
  • [Trick #5] Conditional computation avoids applying all model parameters to all tokens from the input sequence. CoLT5 applies heavy computations only to the most important tokens and processes the rest of the tokens with a lighter version of layers. It will speed up both training and inference.
  • [Trick #6] To fit a large context, you need a lot of RAM in GPU, so people use 80GB A100 GPUs.

To sum up, the more you speed up the training and inference, the larger the context length you can use.

Let’s now discuss all these points in more detail.

Context length is one of the critical limitations of LLMs. And increasing it to already 100K is an incredible achievement (I wonder how this statement will look in a year).

One of the important use cases where people want to apply LLMs is “dropping a large pile of custom data into an LLM” (documents related to the company or a particular problem, various heterogeneous texts, etc) and asking questions about this particular data, not some abstract data from the internet that LLM saw during training.

To overcome this limitation now, people do various things:

  • Trying summarization techniques and sophisticated chained prompts
  • Maintaining vector databases to keep embeddings for custom documents and then “searching” across them by some similarity metric
  • Fine-tuning the LLM with custom data when possible (not all commercial LLMs allow that, and it is not an obvious task for open-source LLMs)
  • Developing custom smaller LLMs for this particular data (again, not an obvious task)

Having a large context length allows an already powerful LLM (that saw the whole internet) to look at your context and data and interact with you on a completely different level with a higher personalization. And all this without changing the model’s weights, doing your “training” on the fly, “in memory.” And overall, a large context window brings more accuracy, fluency, and creativity to the model.

One analogy here might be computer RAM, where the operating system keeps the real-time context of all your applications. With a substantial context length, an LLM can be like a “reasoning computer,” keeping a lot of user context.

It’s important to note that in Transformer architecture, the shapes of all learnable matrix weights are not dependent on the number of input tokens n. All trainable parameters (embedding lookup, projection layers, softmax layer, and attention layers) do not depend on input length and must handle variable-length inputs. That’s great that we have this out-of-the-box property of the architecture.

That means if you trained a Transformer model with a context length of 2K, you could run inference on token sequences of any size. The only problem is that the model will not produce meaningful results on 100K tokens during inference if it isn’t trained on 100K context length. In this case, the training data distribution will be far from the one seen during inference, so the model will fail, like any machine learning model in this setup.

One solution to train a large context length Transformer is to train it in two stages: train the base model on 2K tokens context length and then continue training (fine-tuning) on longer contexts (for example, 65K or 100K). That’s precisely what MosaicML did. But the problem is that it won’t work with the original Transformer architecture, so you need to use some tricks (see Trick #1 later in the post).

Recap on Multi-Head Attention

Challenges of a large context length are related to the computational complexity of the transformer architecture. To discuss the complexity, first, let’s recap how the attention layer works.

Q, K, V: queries, keys, and values. The notation comes from the paper and relates to information retrieval, where you submit a “query” to the system and search for the closest “key”

n: the number of input tokens

d: text embedding dimension

h: the number of attention heads

k: linear projection size for Q and K

v: linear projection size for V

Multi-Head Attention:

  1. We have a lookup Embedding layer that, for a given token, returns a vector of size (1, d). Thus, for a sequence of n tokens, we get the text embeddings matrix X of size (n, d). Then we sum it up with the Positional Sinusoidal Embedding.
  2. The Multi-Head Attention layer aims to calculate the new embedding for this sequence of tokens that can be considered as an original text encoding X but weighted (1) by relative importance between tokens with regards to the context and (2) by relative positions of tokens.
  3. We process this embedding matrix X (n, d) in parallel with h attention layers (heads). To get Q, K, and V for all attention heads, you linearly project X to k, k, and v dimensions, respectively. You do it by multiplying X by h matrices of shape (d, k), (d, k), and (d, v). You can think about it as multiplying (n, d) by (h, d, k), (h, d, k), and (h, d, v).
  4. Attention Heads return h attention scores matrices of size (n, v). Then we concatenate pieces from all heads (n, h*v) and linearly project it for the next steps.
High-level schema of the attention architecture from the Attention is All You Need paper

Scaled Dot-Product Attention:

Now, let’s zoom in on one attention head.

  1. Q, K, V are 3 linear projections of X of size (n, k), (n, k), and (n, v), obtained by multiplying X by learnable weights that are separate for each head.
  2. We get attention scores by calculating the similarity (dot product) between Q and K (transposed). You multiply matrix (n, k) by (k, n) and get the matrix (n, n). Then we scale it, apply the mask that hides some of the tokens (required in the decoder), and apply softmax so that the values are between 0 and 1. This way, we get the matrix of shape (n, n) with a_ij, a relative attention score from 0 to 1 between the i-th and j-th token that shows how “close” these tokens are in this particular context of length n.
  3. Then we multiply this attention score matrix (n, n) by “values” V of size (n, v) to get the text embedding weighted by these relative attention scores.
In the original paper, the attention output of one head is calculated by this formula: Attention(Q, K, V) = softmax(Q*Kᵀ / √k) * V, where the paper writes the projection size k as d_k.

Let’s look at this piece of code from the Multi-Query attention paper. It shows how the Multi-Head Attention is calculated with batching, and the shapes are clear on every step. They also include masking multiplication used during decoding.

A very nice code showing the shapes of every step in the attention layer. From Multi-Query paper.
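To make the shapes above concrete, here is a minimal NumPy sketch of the same steps. It is not the code from the Multi-Query paper: the function name `multi_head_attention`, the weight names, and the output projection `Wo` of shape (h, v, d) are my own assumptions for illustration, and batching is left out.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, causal=True):
    """X: (n, d); Wq, Wk: (h, d, k); Wv: (h, d, v); Wo: (h, v, d) (assumed layout)."""
    n = X.shape[0]
    proj_size = Wq.shape[-1]
    Q = np.einsum("nd,hdk->hnk", X, Wq)                 # (h, n, k)
    K = np.einsum("nd,hdk->hnk", X, Wk)                 # (h, n, k)
    V = np.einsum("nd,hdv->hnv", X, Wv)                 # (h, n, v)
    S = np.einsum("hnk,hmk->hnm", Q, K) / np.sqrt(proj_size)  # (h, n, n)
    if causal:
        # Decoder mask: token i may only look at tokens j <= i.
        allowed = np.tril(np.ones((n, n), dtype=bool))
        S = np.where(allowed, S, -np.inf)
    P = softmax(S, axis=-1)                             # (h, n, n), rows sum to 1
    O = np.einsum("hnm,hmv->hnv", P, V)                 # (h, n, v)
    Y = np.einsum("hnv,hvd->nd", O, Wo)                 # concat heads + project: (n, d)
    return Y, P

rng = np.random.default_rng(0)
d, h, k, v = 16, 4, 4, 4
Wq, Wk = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, k))
Wv, Wo = rng.normal(size=(h, d, v)), rng.normal(size=(h, v, d))
for n in (6, 50):  # the same weights handle any sequence length n
    Y, P = multi_head_attention(rng.normal(size=(n, d)), Wq, Wk, Wv, Wo)
    print(n, Y.shape, P.shape)
```

Note that none of the weight shapes mention n, which is why the same parameters can be applied to a sequence of any length.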

The complexity of the Transformer & context length

The complexity of multiplying two matrices (a,b)*(b,c) is O(a*b*c).

We assume that k*h = O(d) for simplicity, and we will use this to derive the complexity of the attention.

The complexity of the attention layer consists of two parts:

  1. Linear projections to get Q, K, V: multiplication of embedding matrix of size (n, d) by h learnable matrices (d, k), (d, k), and (d, v). Thus, the complexity ~ O(nd²)
  2. Multiplication of Q by K transposed and then multiplication by V: (n,k) * (k,n) = (n,n) and (n,n)*(n,v) = (n,v). The complexity ~ O(n²d)

So, the complexity of the attention layer is O(n²d + nd²), where n — is the context length (number of input tokens) and d — embedding size. So from here, we see that the complexity of the attention layer computation is quadratic w.r.t the number of input tokens n and quadratic w.r.t embedding size d.

The term O(nd²) is important when d > n (for example, in LLaMa, n=2K and d=4K).

The term O(n²d) is important when n > d (for example, training MosaicML with n=65K and d=4K).
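A quick way to see which term dominates is to plug the two examples into the formulas. This is only a back-of-the-envelope sketch: the helper `attention_cost_terms` is a hypothetical name, and constants and the number of layers are ignored.

```python
def attention_cost_terms(n, d):
    """Leading terms of the attention-layer cost, up to constant factors."""
    projections = n * d * d   # O(n d^2): linear projections for Q, K, V
    scores = n * n * d        # O(n^2 d): Q*K^T and the multiplication by V
    return projections, scores

for label, n, d in [("n=2K, d=4K", 2_000, 4_000), ("n=65K, d=4K", 65_000, 4_000)]:
    projections, scores = attention_cost_terms(n, d)
    print(f"{label}: scores/projections = {scores / projections:.2f}")
```

The ratio of the two terms is simply n/d: for n=2K, d=4K the projection term is twice the score term, while for n=65K, d=4K the score term is about 16 times larger.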

Just to remind you how bad the quadratic growth is:

2 000² = 4 000 000, 100 000² = 10 000 000 000.

Let me give you an example of how this quadratic complexity influences the price of model training. The estimated price of training LLaMA was ~$3M, and it has 65B parameters, 2K context length, and 4K embedding size. The estimated price is mostly GPU training time. If we increase the context length from 2K to 100K (50x), the training time will increase ~50x as well (we need fewer iterations because the context is larger, but each one takes longer). So, training LLaMA on 100K context would cost around $150M.

A bit of details on this calculation:

If the number of tokens equals n, the complexity of the attention is O(n²d + nd²), and it takes M iterations to train. If we increase the context length from n to p*n, it will require M/p iterations since the context length became larger (let’s assume for simplicity it’s linear; it might be an overestimation or underestimation depending on the task). Now we have 2 equations:

(1) Complexity for n ~ M * (n²d + nd²)

(2) Complexity for p*n ~ M/p * ((p*n)²d + (p*n)d²)

After a series of simplifications and divisions, the ratio (2)/(1) ~ (d + p*n)/(d + n)

If d << n, increasing n by a factor of p will lead to ~p times more computation.

If d ~ n, increasing n by a factor of p will lead to ~p/2 times more computation.
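The ratio above is easy to sanity-check numerically. The sketch below compares the direct computation from equations (1) and (2) with the simplified form; the function names are hypothetical, and the same simplifying assumption (M/p iterations for a p times longer context) is baked in.

```python
import math

def training_cost_ratio(n, d, p):
    """Cost at context p*n divided by cost at context n, from equations (1) and (2)."""
    base = n**2 * d + n * d**2                         # per-iteration cost at n (M cancels out)
    longer = ((p * n) ** 2 * d + (p * n) * d**2) / p   # M/p iterations at p*n
    return longer / base

def closed_form_ratio(n, d, p):
    """Simplified ratio (d + p*n) / (d + n)."""
    return (d + p * n) / (d + n)

# d ~ n: the ratio is close to p/2
print(training_cost_ratio(2_000, 2_000, 50), closed_form_ratio(2_000, 2_000, 50))
# d << n: the ratio is close to p
print(training_cost_ratio(64_000, 64, 10), closed_form_ratio(64_000, 64, 10))
assert math.isclose(training_cost_ratio(64_000, 64, 10), closed_form_ratio(64_000, 64, 10))
```

With n = d = 2,000 and p = 50, both functions give 25.5, which is close to p/2.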

Difference between training and inference stages in Transformer

The last thing to discuss before digging into optimization techniques is the difference in computation during training and inference.

During training, you run things in parallel, while for text generation during inference, you need to do it sequentially because the next token depends on previous ones. The straightforward way to implement the inference is to calculate attention scores incrementally and cache previous results for future tokens.

This distinction brings different approaches to speeding up training and inference. That is why some tricks below will optimize both stages, but some will optimize only the inference.
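To illustrate the caching idea, here is a small single-head NumPy sketch under my own assumptions (no batching, hypothetical function names). The parallel version computes all positions at once, as in training; the incremental version processes one token at a time and keeps the keys and values of previous tokens in a cache, as in generation. Both give the same result.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(X, Wq, Wk, Wv):
    """Training-style: all n positions in parallel with a causal mask."""
    n = X.shape[0]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    S = Q @ K.T / np.sqrt(Wq.shape[1])
    S = np.where(np.tril(np.ones((n, n), dtype=bool)), S, -np.inf)
    return softmax(S) @ V

def incremental_attention(X, Wq, Wk, Wv):
    """Inference-style: one token at a time, reusing cached keys and values."""
    k_cache, v_cache, outputs = [], [], []
    for x in X:
        k_cache.append(x @ Wk)   # the cache grows by one key ...
        v_cache.append(x @ Wv)   # ... and one value per generated token
        q = x @ Wq
        K, V = np.stack(k_cache), np.stack(v_cache)
        s = K @ q / np.sqrt(Wq.shape[1])   # scores against all cached keys
        outputs.append(softmax(s) @ V)
    return np.stack(outputs)

rng = np.random.default_rng(0)
X = rng.normal(size=(7, 16))
Wq, Wk, Wv = (rng.normal(size=(16, 8)) for _ in range(3))
print(np.allclose(full_causal_attention(X, Wq, Wk, Wv),
                  incremental_attention(X, Wq, Wk, Wv)))
```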

Now, let’s talk about how researchers overcame all these challenges and were able to train an LLM with a large context length.

[Trick #1] Better positional encoding — ALiBi

One solution to train a large context length Transformer is to train it in two stages: train the base model on 2K tokens context length and then fine-tune on longer contexts (for example, 65K). But earlier, we said it wouldn’t work with the original Transformer architecture. Why?

Because of the Positional Sinusoidal Encoding, which has no “extrapolation” ability. In the ALiBI[4] paper, the authors showed that Positional Sinusoidal Encoding is not robust to the extension of the context window during inference. After a few more tokens, the performance starts degrading. So, lack of “extrapolation” ability basically means you can’t use larger context lengths during inference/fine-tuning than during training. The term “extrapolation” and the comparison of various positional encodings are described in [4].

In the original transformer paper, the Positional Sinusoidal Embedding is summed with the token embeddings at the bottom of the architecture to add information about the order of words. If you want to learn how the Positional Sinusoidal Embedding is calculated, I recommend this fun video, where it is explained intuitively and in good detail.

So, the first trick is to remove Positional Sinusoidal Embedding and replace it with another position embedding — Attention with Linear Biases (ALiBI).

It is applied in the attention head (not on the bottom of the network), and it biases query-key attention scores with a penalty that is proportional to their distance (before softmax).

This trick speeds up training.

When computing attention scores for each head, ALiBi adds a constant bias (right) to each attention score (qi · kj, left). As in the unmodified attention sublayer, the softmax function is then applied to these scores, and the rest of the computation is unmodified. m is a head-specific scalar that is set and not learned throughout the training. From the ALiBi paper.
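The bias matrix from the figure can be sketched in a few lines of NumPy. This follows my reading of the ALiBi paper for a number of heads that is a power of two (slopes form the geometric sequence 2^(-8/h), 2^(-16/h), ...); the function names are hypothetical, and the causal mask is folded into the same matrix for convenience.

```python
import numpy as np

def alibi_slopes(h):
    """Head-specific slopes m, fixed and not learned (assumes h is a power of two)."""
    return np.array([2.0 ** (-8.0 * (i + 1) / h) for i in range(h)])

def alibi_bias(n, h):
    """Bias of shape (h, n, n) that is added to q_i . k_j before softmax."""
    pos = np.arange(n)
    distance = pos[:, None] - pos[None, :]               # i - j, shape (n, n)
    bias = -alibi_slopes(h)[:, None, None] * distance    # penalty grows with distance
    # Future tokens (j > i) are masked out, as in a causal decoder.
    return np.where(distance >= 0, bias, -np.inf)

print(alibi_slopes(8))
print(alibi_bias(4, 8)[0])
```

Because the bias depends only on the distance i - j and not on n, the same function produces a valid bias for a 2K or a 100K context, which is what makes fine-tuning on longer contexts possible.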

[Trick #2] Sparse Attention

Not all tokens in the context of size 100K are relevant to each other. One way to reduce the number of computations is to consider only some tokens when calculating the attention scores. The goal of adding sparsity is to make the computation linear in n, not quadratic. There are several approaches to selecting the connections between tokens, and there is an excellent illustration of this in the Google blog post:

Full attention can be viewed as a complete graph. Sparse Attention Methods

For example, the Sliding Window Attention (also called Local) employs a fixed-size attention window surrounding each token. In this attention pattern, given a fixed window size of w, each token attends to w/2 tokens on each side. The computational complexity of this pattern is O(n*w), which scales linearly with input sequence length n. To make it efficient, w should be small compared with n. The trick is that the attention information “flows” through the whole context window via neighboring tokens, approximating the full graph.
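As a small illustration of the sliding window pattern, the sketch below builds the boolean mask of allowed token pairs and counts them. The helper name `sliding_window_mask` is hypothetical, and a real implementation would never materialize the full (n, n) mask; this is only to show how few pairs are kept.

```python
import numpy as np

def sliding_window_mask(n, w):
    """True where token i may attend to token j, i.e. |i - j| <= w // 2."""
    pos = np.arange(n)
    return np.abs(pos[:, None] - pos[None, :]) <= w // 2

n, w = 16, 4
mask = sliding_window_mask(n, w)
print(mask.sum(), "of", n * n, "pairs are computed")
assert mask.sum() <= n * (w + 1)   # at most w + 1 pairs per token: O(n*w)
```

For n = 16 and w = 4, only 74 of the 256 pairs are kept, and the number of kept pairs grows linearly with n for a fixed w.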

The BigBird attention score method combines global, local, and random mechanisms. In the paper, the authors showed a crucial observation that there is an inherent tension between how few similarity scores one computes and the flow of information between different nodes (i.e., the ability of one token to influence each other).

This trick speeds up both training and inference.

[Trick #3] FlashAttention — efficient implementation of the attention layer for GPU

Several computational operations in the attention layer are repeated over and over again:

  1. S = Q*Kᵀ
  2. P = softmax(S)
  3. O = P*V

Remember the notation for the S, P, and O results; we will use it later. FlashAttention authors “fused” these operations: they implemented an attention layer algorithm that utilizes the GPU memory efficiently and calculates the exact attention.

For a GPU to make an operation, the input data must be present in the “quick” memory named SRAM. The data is copied from “slow” HBM memory to SRAM and returned back to HBM once the computation is over. SRAM memory is much faster than HBM but much smaller in size (20MB vs 40GB in A100 40GB GPU).

A100 GPU Memory Hierarchy. FlashAttention paper

So, accessing the HBM is an expensive operation.

The main problem in the attention layer w.r.t. GPU memory utilization is the “intermediate” multiplication results S and P, which are large in size (n, n). We need to save them to HBM and read them again between attention operations. Moving these intermediate results between HBM and SRAM back and forth is the bottleneck, which the authors solved in the paper.

The main idea behind the FlashAttention algorithm is to split the inputs Q, K, and V matrices into blocks, loading these blocks from HBM to SRAM and then computing the attention output w.r.t those blocks. This procedure is named tiling.

Left: FlashAttention uses tiling to prevent materialization of the large n × n attention matrix (dotted box) on HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to SRAM. In each block, FlashAttention loops over blocks of Q matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: 7.6× speedup. FlashAttention paper

The “matrix multiplication” operation is already optimized for GPU. You might think of this FlashAttention algorithm as implementing the “attention layer” operation optimized for GPU. The authors “fused” operations of several multiplications and softmax with tiling and optimized HBM accessing.
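The tiling idea can be imitated in plain NumPy to show why the result stays exact. This is only a sketch of the math, not the FlashAttention kernel: it runs on CPU, says nothing about SRAM or HBM, omits the 1/√k scaling to match the S = Q*Kᵀ notation used here, and processes all rows of Q at once instead of looping over Q blocks. The function names and the `block` parameter are my own.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Materializes the full (n, n) matrices S and P."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V

def tiled_attention(Q, K, V, block=4):
    """Processes K and V block by block, keeping only running statistics per row."""
    n = Q.shape[0]
    O = np.zeros((n, V.shape[1]))     # unnormalized output accumulator
    row_max = np.full(n, -np.inf)     # running maximum of the scores per row
    row_sum = np.zeros(n)             # running softmax denominator per row
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T                                  # only an (n, block) tile
        new_max = np.maximum(row_max, S.max(axis=-1))
        scale = np.exp(row_max - new_max)             # rescale what was accumulated so far
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * scale + P.sum(axis=-1)
        O = O * scale[:, None] + P @ Vb
        row_max = new_max
    return O / row_sum[:, None]

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(10, 8)), rng.normal(size=(10, 8)), rng.normal(size=(10, 8))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V, block=4)))
```

The running maximum and running sum are what allow the softmax to be computed block by block without ever holding the full (n, n) matrix.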

There is a good overview of the FlashAttention paper.

PyTorch 2.0 now has FlashAttention built in. There is also a FlashAttention implementation in the Triton language by the authors.

This trick speeds up both training and inference.

[Trick #4] Multi-Query attention (MQA)

The original Multi-Head Attention (MHA) has a separate linear layer for K and V matrices in every head.

During inference, the keys and values of previous tokens in the decoder are cached to prevent re-computing them, so GPU memory usage grows with each generated token.

Multi-Query attention (MQA) is the optimization that suggests sharing weights across all attention heads when linearly projecting K and V, so we would need to keep only 2 matrices of size (n, k) and (n, v). A big model can have up to 96 heads (such as GPT-3), which means using MQA can reduce the memory consumption of the key/value decoder cache by 96x.

This optimization is especially beneficial when generating long texts. For example, having a large context length and asking for a long, meaningful analysis or summarization.

The main advantage of this approach is a significant speedup of the incremental attention scores calculation during inference. Training speed stays mostly the same. For example, PaLM uses it.
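The saving in the key/value cache is simple arithmetic. The sketch below counts cached elements per layer; the helper name is hypothetical, and the head size of 128 for k and v is my assumption for the example.

```python
def kv_cache_elements(n, h, k, v, multi_query=False):
    """Number of cached key/value elements per layer for n tokens."""
    kv_heads = 1 if multi_query else h   # MQA keeps a single shared K and V
    return kv_heads * n * (k + v)

n, h, k, v = 100_000, 96, 128, 128       # head size 128 is an assumed value
mha = kv_cache_elements(n, h, k, v)
mqa = kv_cache_elements(n, h, k, v, multi_query=True)
print(mha, mqa, mha // mqa)
```

The ratio equals the number of heads h, regardless of the context length or head size.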

[Trick #5] Conditional computation

When d > n, the bottleneck in speed is not the attention layer but the feedforward and projection layers. A common approach to reducing the FLOPs is employing some form of conditional computation that avoids applying all model parameters to all tokens from the input sequence.

In the Sparse Attention section, we’ve discussed that some tokens are more important than others. Following the same intuition, in the CoLT5 paper, the authors separated all feedforward and attention computations into two branches: heavy and light. Light layers are applied to all tokens, and the heavy ones only to important ones.

“The light and heavy feedforward branches differ only in their hidden dimension, with the light branch having a smaller hidden dimension than the standard T5 feedforward layer and the heavy branch larger”.

This approach has been shown to outperform both the speed and accuracy of the existing LongT5 model for extremely long sequences up to 64K input tokens.

An overview of a CoLT5 Transformer layer with conditional computation. All tokens are processed by light attention and MLP layers, while q routed query tokens perform heavier attention over v routed key-value tokens and m routed tokens are processed by a heavier MLP. CoLT5 paper
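To give a feeling for the routing idea in the feedforward part, here is a heavily simplified NumPy sketch. It is not the CoLT5 implementation: the router is a single assumed vector, the routing scores are used only to pick the top m tokens (they are not used to weight the heavy branch), and the names `ffn` and `conditional_ffn` are hypothetical.

```python
import numpy as np

def ffn(X, W1, W2):
    """Plain feedforward layer with a ReLU; the hidden size is W1.shape[1]."""
    return np.maximum(X @ W1, 0.0) @ W2

def conditional_ffn(X, router, light, heavy, m):
    """X: (n, d); router: (d,); light and heavy: (W1, W2) pairs; m: tokens sent to the heavy branch."""
    scores = X @ router                  # one routing score per token
    top = np.argsort(scores)[-m:]        # indices of the m highest-scoring tokens
    out = ffn(X, *light)                 # light branch runs on all n tokens
    out[top] += ffn(X[top], *heavy)      # heavy branch runs on m tokens only
    return out, top

rng = np.random.default_rng(0)
n, d, m = 10, 8, 3
X = rng.normal(size=(n, d))
router = rng.normal(size=d)
light = (rng.normal(size=(d, 4)), rng.normal(size=(4, d)))     # small hidden dimension
heavy = (rng.normal(size=(d, 32)), rng.normal(size=(32, d)))   # large hidden dimension
out, top = conditional_ffn(X, router, light, heavy, m)
print(out.shape, sorted(top.tolist()))
```

The heavy branch costs O(m) instead of O(n) token passes, so with m much smaller than n most of the expensive computation is skipped.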

[Trick #6] Large RAM GPUs

It’s not a trick but a necessity. To fit a large context, you need large RAM in GPU, so people use 80GB A100 GPUs.

Wow, that’s a lot. I didn’t expect to end up with such a long blog post :D

I hope it was helpful! I learned a lot, and I hope you did too, and now we can guess how these Large Language Models with billions of parameters were trained in unprecedented context windows of 65-100K tokens.

Inspiring to see how different smart people address the same problem from different sides, optimize here and there, and come up with cool ideas. All these lead to a meaningful and elegant solution.

I like what one Researcher said about training the LLM with a large context: “No secret sauce, just well-vetted research.”

[1] Introducing 100K Context Windows by Anthropic

[2] MPT-7B by MosaicML

[3] PaLM 2 Technical report by Google

[4] ALiBI: Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

[5] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

[6] Multi-Query attention: Fast Transformer Decoding: One Write-Head is All You Need

[8] Attention is All You Need

[9] Video on Positional Sinusoidal Embedding

[10] Overview of the FlashAttention paper

[11] Sliding Window Attention

[12] Constructing Transformers For Longer Sequences with Sparse Attention Methods

[13] FlashAttention implementation in Triton language

[14] How to Accelerate HuggingFace Throughput by 193% with Triton and ClearML

[15] ClearML Serving

[16] Analyzing the Pros and Cons of NVIDIA Triton Inference Server vs. Other Inference Engines

[17] CoLT5: Faster Long-Range Transformers with Conditional Computation

[18] LongT5: Efficient Text-To-Text Transformer for Long Sequences

[19] PaLM

[20] BigBird attention mechanism
