Attention mechanism
Understand how attention lets transformers focus on relevant context, why it replaced RNNs, and what query-key-value matrices mean at an intuitive level for engineers building AI systems.
TL;DR
- Attention is a learned routing mechanism where every token computes a weighted sum over all other tokens, with weights determined by query-key dot products.
- The core formula is softmax(QK^T / sqrt(d_k)) * V, but the intuition is simpler: find the best-matching context and retrieve its content.
- Multi-head attention runs N parallel attention functions (96 per layer in GPT-3, reportedly 120 in GPT-4) so different heads can specialize in different relationship types.
- Self-attention cost is O(n^2) in sequence length. Flash Attention cuts memory to O(n) but does not reduce FLOPs.
- Causal masking prevents decoder tokens from attending to future positions, which is what makes autoregressive generation valid.
- QKV projections and causal masking are the two highest-signal concepts for AI system design interviews.
The problem it solves
Recurrent neural networks (RNNs) and LSTMs process tokens one at a time. Each step updates a fixed-size hidden state vector and passes it forward. By the time the model reaches token 500, everything from token 1 has been compressed, overwritten, and degraded through hundreds of sequential updates.
For a 10,000-word legal contract, the information from the first clause is effectively gone by the time the model reads the last. I've seen teams try to fix this with bidirectional LSTMs and stacked layers, only to hit the same wall: a single hidden state vector cannot encode an entire document.
The second problem is parallelism. RNNs are inherently sequential: token N requires the output of token N-1. You cannot parallelize across the sequence dimension during training. GPU utilization suffers, and training on long sequences is slow regardless of hardware.
Attention fixes both problems. Every token can directly attend to every other token in constant depth (not sequential), and the entire computation parallelizes across the sequence dimension on a GPU.
What is it?
Attention is a learned routing mechanism that computes a weighted sum of context tokens, where the weights reflect each token's relevance to the current query.
Think of it like a library search. You walk to the reference desk with a question (your query). Every book on the shelf has a spine label describing its contents (the key). You scan the spines, find the most relevant ones, and pull those books off the shelf to read their contents (the values). Attention works the same way: the query finds matching keys, and the output is a weighted blend of the corresponding values.
The weights are learned, not hand-engineered. The model learns that a verb at position 47 should attend strongly to its subject at position 3, skipping everything irrelevant in between. No rule told it to do this. The relationship emerges from training data.
For your interview: say "attention computes a weighted average of value vectors, where the weights come from how well each key matches the query." That is the entire algorithm in one sentence.
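To make that one-sentence version concrete, here is a minimal sketch for a single query, assuming NumPy is available. The function name `attend_one` and the toy keys and values are hypothetical, chosen only so the first "book" is the obvious match. Scaling by sqrt(d_k) is left out here to keep the example as small as possible.

```python
import numpy as np

def attend_one(query, keys, values):
    """Return a weighted average of `values` and the weights used."""
    scores = keys @ query                      # one match score per key
    weights = np.exp(scores - scores.max())    # softmax, shifted for stability
    weights = weights / weights.sum()          # weights now sum to 1
    return weights @ values, weights

# Hypothetical toy data: three "books" with 2-d spine labels (keys)
# and 2-d contents (values).
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
values = np.array([[10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]])
query = np.array([4.0, 0.0])   # most similar to the first key

output, weights = attend_one(query, keys, values)
```

The first key gets almost all of the weight, so `output` lands close to the first value vector, with small contributions from the other two.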
How it works
Query, key, and value projections
Each token embedding is projected three ways using learned weight matrices W_Q, W_K, and W_V. The result is three vectors per token.
- Query (Q): "What am I looking for?" The current token's representation of what it needs from context.
- Key (K): "What do I offer?" Each context token's advertisement of what information it contains.
- Value (V): "What I'll contribute if selected." The actual content that gets passed forward when a key matches a query.
These projections are separate learned linear transformations. The same token produces different Q, K, and V vectors because each matrix extracts a different facet of the embedding. In my experience, this is where most engineers' understanding breaks down: they confuse the projection step with the scoring step.
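The projection step on its own can be sketched as three matrix multiplications. This assumes NumPy; the sizes are illustrative rather than taken from any real model, and the random matrices stand in for weights that a real model would learn during training.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 4   # illustrative sizes only

# One embedding row per token.
X = rng.normal(size=(seq_len, d_model))

# Random stand-ins for the learned matrices W_Q, W_K, W_V.
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

Q = X @ W_Q   # what each token is looking for
K = X @ W_K   # what each token offers
V = X @ W_V   # what each token contributes if selected
```

Note that no scoring has happened yet: this step only produces three views of the same tokens. Scoring starts when Q is compared against K.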
Scaled dot-product attention (the math, with intuition first)
Before the formula, here is the intuition. Each query asks "which keys are most similar to me?" The dot product between Q and K measures that similarity (higher dot product means more aligned vectors). We normalize these scores with softmax so they sum to 1 (probability distribution). Then we use those probabilities to blend the value vectors.
The formula: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
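The formula maps almost line for line onto code. The following is a minimal single-head sketch for 2-d arrays, assuming NumPy; the function names and the random inputs are illustrative, not part of any library.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """softmax(QK^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n_q, n_k) similarity scores
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights          # blend the value vectors

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 4))
K = rng.normal(size=(5, 4))
V = rng.normal(size=(5, 2))

out, weights = scaled_dot_product_attention(Q, K, V)
```

Each row of `weights` is one query's probability distribution over the five keys, and each row of `out` is the corresponding blend of value vectors.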
Why divide by sqrt(d_k)? If the components of Q and K have roughly unit variance, the variance of their dot product grows linearly with d_k, so typical magnitudes grow like sqrt(d_k). For d_k = 64, that is a standard deviation of about 8, and over a long context gaps of 20 or more between scores are common. When gaps that large hit softmax, the output becomes extremely peaked (one position gets ~1.0, everything else gets ~0.0). The gradients through softmax vanish, and the model stops learning. Dividing by sqrt(64) = 8 brings the variance back to about 1 and keeps softmax in its useful range.
The scaling factor is not optional
I've seen custom attention implementations skip the scaling factor "because it's just a constant." It is not optional. Without it, training destabilizes after a few thousand steps because softmax saturates and gradients collapse. Always scale by sqrt(d_k).
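A quick numerical check of the variance argument, assuming NumPy and assuming query and key components drawn independently from a standard normal (a rough stand-in for freshly initialized projections, not a claim about trained models):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d_k, n_samples = 64, 10_000

# Independent unit-variance queries and keys.
q = rng.normal(size=(n_samples, d_k))
k = rng.normal(size=(n_samples, d_k))
dots = (q * k).sum(axis=1)               # one dot product per (q, k) pair

raw_std = dots.std()                     # close to sqrt(64) = 8
scaled_std = (dots / np.sqrt(d_k)).std() # close to 1

# Treat the first 16 dot products as one query's scores over 16 keys.
scores = dots[:16]
peak_unscaled = softmax(scores).max()
peak_scaled = softmax(scores / np.sqrt(d_k)).max()
```

The unscaled scores have a spread of about 8, and their softmax puts far more weight on the single largest score than the scaled version does. That concentration is the saturation the warning above is about.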
Multi-head attention
Running a single attention function gives you one routing pattern. Multi-head attention runs N independent attention computations in parallel, each with its own W_Q, W_K, and W_V matrices. Each "head" learns to specialize in a different type of relationship.
Think of it as N different librarians, each expert in a different topic. One head might track syntactic dependencies (verb to subject). Another handles coreference (pronoun to antecedent). Another attends to positional proximity. The outputs of all heads are concatenated and projected back to the model dimension through a final linear layer.
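The split, attend, concatenate, project sequence can be sketched as follows, assuming NumPy. One implementation choice to flag: instead of storing N separate small matrices per head, this sketch packs them into single (d_model, d_model) matrices and slices the result into heads, which is mathematically equivalent. All names, sizes, and random weights are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    """X: (seq_len, d_model). Each W_*: (d_model, d_model)."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads

    def split_heads(M):
        # (seq_len, d_model) -> (n_heads, seq_len, d_head)
        return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    Q = split_heads(X @ W_Q)
    K = split_heads(X @ W_K)
    V = split_heads(X @ W_V)

    # Every head computes its own attention pattern independently.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (n_heads, seq, seq)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                  # (n_heads, seq, d_head)

    # Concatenate the heads, then project back to the model dimension.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 5, 8, 2
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads)
```

`weights` holds one (seq_len, seq_len) routing pattern per head, which is what lets different heads attend to different things for the same token.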
GPT-3 uses 96 heads in each of its 96 layers. GPT-4 reportedly uses 120 heads, though its architecture has not been published. Research on head pruning shows that removing any single head rarely hurts performance, but removing many heads degrades quality fast. The redundancy is a feature, not waste.
Causal masking in decoder models
In a decoder-only model (GPT, Claude, Llama), the token at position 47 cannot look at the token at position 48: at inference time that token has not been generated yet, and at training time it is the answer the model is supposed to predict. This is enforced by adding negative infinity to the attention scores for all future positions before softmax. After softmax, those positions get zero weight.
This is called causal masking or autoregressive masking. Without it, the model could "cheat" during training by reading ahead. The mask makes the training objective valid: predict the next token using only previous tokens.
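A minimal sketch of the mask, assuming NumPy. The function name is hypothetical, and real implementations usually apply the same idea to batched, multi-head tensors.

```python
import numpy as np

def causal_attention_weights(Q, K):
    """Attention weights where position i can only see positions <= i."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)

    # True strictly above the diagonal, i.e. for every future position.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)

    # Softmax: exp(-inf) is exactly 0, so masked positions get zero weight.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(4, 8))

weights = causal_attention_weights(Q, K)
```

The result is lower-triangular: the first token can only attend to itself, the second to the first two, and so on. Each row still sums to 1 because softmax renormalizes over the positions that remain visible.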
Self-attention vs cross-attention
Self-attention is when a sequence attends to itself: the same tokens produce Q, K, and V. Cross-attention is when the queries come from one sequence (the decoder) while the keys and values come from another (the encoder output). Decoder-only models like GPT use self-attention exclusively. Encoder-decoder models like T5 use both. Don't conflate the two in interviews.
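In code, the only difference between the two is where Q, K, and V come from. The sketch below assumes NumPy and, to keep the focus on that difference, skips the learned projections and feeds the raw states in directly. The variable names and sizes are illustrative.

```python
import numpy as np

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
d = 4
decoder_states = rng.normal(size=(3, d))   # 3 target-side tokens
encoder_states = rng.normal(size=(6, d))   # 6 source-side tokens

# Self-attention: Q, K, and V all come from the same sequence.
self_out, self_w = attention(decoder_states, decoder_states, decoder_states)

# Cross-attention: Q from the decoder, K and V from the encoder output.
cross_out, cross_w = attention(decoder_states, encoder_states, encoder_states)
```

The weight matrix shapes tell the story: `self_w` is 3 by 3 (each decoder token over the decoder tokens), while `cross_w` is 3 by 6 (each decoder token over the encoder tokens).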
Positional encoding: why attention needs it
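Attention has no built-in sense of order. Strictly speaking, self-attention is permutation-equivariant: if you shuffle the tokens in a sentence, the outputs are shuffled the same way, but each token's output is unchanged. "Dog bites man" and "man bites dog" produce the same set of Q, K, and V vectors (just at shuffled positions), so without positional information the mechanism cannot tell the two sentences apart.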
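This is easy to verify numerically. The sketch below, assuming NumPy and random stand-in weights, runs unmasked self-attention with no positional encoding on a three-token sequence and on a shuffled copy of it.

```python
import numpy as np

def self_attention(X, W_Q, W_K, W_V):
    """Unmasked single-head self-attention with no positional information."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
d_model = 4
X = rng.normal(size=(3, d_model))          # stand-ins for "dog", "bites", "man"
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_model)) for _ in range(3))

perm = np.array([2, 1, 0])                 # reorder to "man", "bites", "dog"

out = self_attention(X, W_Q, W_K, W_V)
out_shuffled = self_attention(X[perm], W_Q, W_K, W_V)

# Each token gets the same output vector, just at its new position.
same_up_to_order = np.allclose(out_shuffled, out[perm])
```

Because every token ends up with an identical output regardless of where it sits, any layer stacked on top has no signal to distinguish the two word orders unless position is injected somewhere.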