LLM Inference: From Black Box to Production
A ground-up explanation of LLM inference, from black box to production optimizations. Covers tokenization, attention, KV cache, memory bottlenecks, batching, PagedAttention, quantization, and more. Mostly diagrams, plus a few small illustrative code sketches, with TinyLlama as our running example.
What This Post Covers
This post explains how inference works in large language models: the process of taking a prompt and generating a response. We start with the simplest possible view and progressively open the box, layer by layer, until you understand what's actually happening inside.
No prerequisites. Grab a coffee and settle in. We start from zero.
Training vs Inference: A Quick Note
If you've heard about LLMs, you've probably heard about "training" (or "pre-training"), the process where models learn from massive datasets. Inference is different. During training, the model adjusts its internal parameters by learning from mistakes (backpropagation). During inference, the model is frozen. It just runs forward, producing outputs without learning anything new.
This distinction matters for understanding what follows:
- No gradients: Inference doesn't compute or store gradients. No learning happens.
- No optimizer state: Training requires storing momentum, variance, and other optimizer data. Inference doesn't.
- Simpler memory profile: Inference only needs the model weights plus working memory for the current computation.
- Forward passes only: We just compute outputs, never backpropagate.
Everything in this post is about inference: using an already-trained model to generate text.
Our Running Example: TinyLlama
We use TinyLlama 1.1B as our running example throughout this post. It's small enough to reason about concretely, yet it's built from the same transformer architecture as far larger models (it follows the Llama 2 design). We'll introduce its specific parameters (like embedding dimensions and layer counts) as they become relevant. You don't need to memorize anything upfront.
If you want to verify any numbers or explore the architecture yourself, see the TinyLlama model card on HuggingFace. The config.json file there contains all the architecture parameters we'll reference.
By the end of this post, you'll understand the inference pipeline from text input to text output, why generating tokens is fundamentally limited by memory bandwidth rather than compute, and the production techniques that vLLM, TGI, and other systems use to make inference fast.
Part 1: How Inference Works
The next 12 sections build up from black box to full pipeline, one concept at a time.
A Black-Box View
From a user's perspective, an LLM is just a black box. You type "Write a story" and a moment later see "Once upon a time...". Nothing in the interface hints at tokens, matrices, or GPUs. The mental model is simple: text goes in, text comes out.
This diagram is deceptively simple, but worth pausing on. Everything we build from here is opening this box layer by layer.
Detail 1: Tokens
We could think of language as a hierarchy: ideas → sentences → words → letters. LLMs, however, don't work at the letter or whole-word level. Instead, they use an intermediate unit called a token, a chunk of text that's bigger than a character but often smaller than a full word.
For example, where we see one word "responding," the model might see two tokens: ["respond", "ing"]. To us it's one word; to the model it's a sequence of two tokens.
The tokenizer is a deterministic algorithm (not learned) that maps text to integers using a fixed vocabulary. This vocabulary is built before training by analyzing a large corpus and finding common subword patterns. The algorithm (usually Byte-Pair Encoding or SentencePiece) balances vocabulary size against token length: common words get their own tokens, rare words get split into pieces.
TinyLlama has a vocabulary of 32,000 tokens. A prompt like "Write a story" might become:
["Write", " a", " story"] → [8144, 264, 3446]
From this point on, the model only manipulates these integers. The text is gone.
Try it yourself: You can experiment with tokenization at tiktokenizer.vercel.app. Paste in some text and see how different models split it into tokens.
At the output end, a detokenizer reverses this: token ID → text chunk.
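If you'd rather poke at this in code, here's a minimal sketch using the HuggingFace tokenizer for TinyLlama (the exact IDs you get depend on the tokenizer, so treat the IDs in this post as illustrative):

```python
# pip install transformers sentencepiece
from transformers import AutoTokenizer

# TinyLlama's tokenizer: 32,000-entry SentencePiece vocabulary
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

ids = tok.encode("Write a story", add_special_tokens=False)
print(ids)                             # a short list of integers, one per token
print(tok.convert_ids_to_tokens(ids))  # the text chunk each ID stands for
print(tok.decode(ids))                 # detokenize: back to "Write a story"
```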
But wait, how does the model actually understand what these numbers mean? Token 8144 and token 8143 are adjacent integers, but they could represent completely unrelated words. We need a way to give these numbers meaning.
Detail 2a: Embeddings
After tokenization, we have integers like [9566, 261, 4869]. But these are just arbitrary IDs. Token 9566 and token 9565 might be completely unrelated words, yet their IDs are adjacent. The integer itself carries no meaning; it's just a position in a lookup table. You can't do useful math on raw IDs.
What we need is a representation where similar tokens have similar values. "Lion" and "tiger" should be numerically close; "lion" and "teacup" should be far apart. That's what embeddings provide.
An embedding converts each token ID into a dense vector of floating-point numbers. To build intuition, imagine just 2 dimensions. One dimension could encode "danger level," the other "size." Animal words might land in this 2D space like:
Notice that semantically similar concepts cluster together. Lion and scorpion are both dangerous, so they're on the same side of one axis. Elephant and lion are both large, so they're on the same side of the other. The geometry captures meaning.
TinyLlama uses 2,048-dimensional embeddings. So instead of representing "Write" as integer 9566, we represent it as a vector of 2,048 numbers: [0.21, -0.83, 0.45, ..., 0.12]. Every token in the vocabulary gets its own 2,048-dimensional vector.
If meaning is encoded in vectors, we can operate on meaning with math. The direction from "mouse" to "elephant" represents "getting larger while staying harmless." Add that direction to "scorpion" and you might land near "crocodile," something large and dangerous.
System implications: With a 32,000-token vocabulary and 2,048-dimensional embeddings stored as FP16 (2 bytes each):
32,000 × 2,048 × 2 bytes = 131 MB
This embedding table is a matrix of shape [32,000 × 2,048]. When the model sees token ID 9566, it looks up row 9,566 and retrieves that token's vector. That vector, not the original ID, is what flows through the rest of the model.
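As a sketch (random numbers standing in for the learned table), the lookup really is just row indexing:

```python
import numpy as np

vocab_size, d_model = 32_000, 2_048

# Stand-in for the learned embedding table: one 2,048-dim row per token ID.
# (Random values here; the real table is learned during training.)
embedding_table = np.random.randn(vocab_size, d_model).astype(np.float16)

token_ids = [9566, 261, 4869]          # "Write a story" (illustrative IDs)
embedded = embedding_table[token_ids]  # one row per token

print(embedded.shape)                      # (3, 2048)
print(embedding_table.nbytes / 1e6, "MB")  # ≈ 131 MB at FP16
```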
We now have rich semantic vectors for each token. But what happens at the output end? Eventually, we need to get back to text. That's next.
Detail 2b: Unembeddings
We've converted tokens to rich vectors. But eventually, the model needs to produce output. How does it go from vectors back to tokens?
The transformer layers (still a "black box" for now) process our embedded vectors and output a final 2,048-dimensional vector for the last position. This vector encodes everything the model "wants to say" next, but it's not a token yet. We need to convert it back to a choice from our 32,000-token vocabulary.
The LM Head
The LM head (language modeling head) performs this conversion. It's a linear projection: a matrix of shape [2,048 × 32,000] that multiplies the final hidden state to produce 32,000 numbers, one for each token in the vocabulary.
final_vector × W_lm_head = 32,000 scores
[1 × 2,048] × [2,048 × 32,000] = [1 × 32,000]
These 32,000 numbers are called logits: raw, unnormalized scores. They can be any real number, positive or negative. A logit of 5.2 for "Once" doesn't mean 5.2% probability; it's just a score. Higher logits indicate the model thinks that token is more likely to come next.
Softmax
To turn logits into actual probabilities, we apply softmax: each logit xᵢ becomes exp(xᵢ) / Σⱼ exp(xⱼ).
Softmax does two things:
- Exponentiates each score (making them all positive)
- Normalizes so they sum to 1
After softmax, we have a probability distribution over all 32,000 tokens. "Once" might have probability 0.08, "The" might have 0.12, "In" might have 0.05, and so on.
Sampling: Picking the Next Token
Now that we have a probability distribution, how do we pick one token? Your first instinct might be: just pick the most probable one. If "The" has 12% probability and everything else is lower, pick "The." Simple, deterministic, optimal... right?
This approach is called greedy decoding, and it sounds reasonable. But it has a problem: it produces boring, repetitive text. Consider what happens when the model generates "The cat sat on the mat. The cat sat on the..." Once the model produces a phrase, greedy decoding tends to repeat it. The same context leads to the same prediction, which leads to the same context, forever. The model gets stuck in loops.
There's a deeper issue too. Language isn't deterministic. When you write, you don't always pick the "most likely" next word. Sometimes you choose a surprising word that takes the sentence in an interesting direction. Greedy decoding can't do this. It always takes the safe, expected path.
The solution is to sample from the probability distribution instead of just taking the maximum. If "The" has 12% probability and "Once" has 8%, we might pick "Once" sometimes. This introduces controlled randomness that produces more natural, varied, and interesting text.
But pure random sampling has its own problem too: sometimes it picks very unlikely tokens (the ones with 0.001% probability), producing incoherent nonsense. So we use strategies that balance randomness with quality:
- Temperature: Scale the logits before softmax. Temperature > 1 makes the distribution flatter (more random, more creative). Temperature < 1 makes it peakier (more deterministic, more focused). Temperature = 0 is equivalent to greedy decoding.
- Top-k sampling: Only consider the k highest-probability tokens, then sample among them. This filters out the long tail of unlikely tokens while preserving variety among the good options.
- Top-p (nucleus) sampling: Only consider the smallest set of highest-probability tokens whose cumulative probability exceeds p (e.g., 0.9). Unlike top-k, this adapts to the distribution. If the model is very confident, only a few tokens are considered. If it's uncertain, more tokens make the cut. (All three strategies are sketched in code below.)
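Here's a minimal sketch of these three strategies in NumPy. The logits are random stand-ins; real ones come from the LM head:

```python
import numpy as np

def softmax(x):
    x = x - x.max()                      # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

def sample(logits, temperature=1.0, top_k=None, top_p=None):
    # Temperature: scale logits before softmax (T < 1 = peakier, T > 1 = flatter)
    logits = logits / max(temperature, 1e-6)
    probs = softmax(logits)

    if top_k is not None:                # keep only the k most probable tokens
        cutoff = np.sort(probs)[-top_k]
        probs = np.where(probs >= cutoff, probs, 0.0)

    if top_p is not None:                # keep the smallest set covering prob mass p
        order = np.argsort(probs)[::-1]
        cum = np.cumsum(probs[order])
        keep = order[: np.searchsorted(cum, top_p) + 1]
        mask = np.zeros_like(probs)
        mask[keep] = 1.0
        probs = probs * mask

    probs = probs / probs.sum()          # renormalize after filtering
    return np.random.choice(len(probs), p=probs)

logits = np.random.randn(32_000)         # pretend these came from the LM head
print(sample(logits, temperature=0.8, top_k=50, top_p=0.9))
```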
The Complete Output Path
Putting it together, the output side of the model works like this:
- Transformer outputs a 2,048-dim vector for the last position
- LM head projects it to 32,000 logits
- Softmax converts logits to probabilities
- Sampling picks one token from the distribution
- Detokenizer converts the token ID back to text
Embedding and the LM head are bookends: one converts a token ID into a rich vector (embedding), the other converts a rich vector back into a token ID (unembedding). The transformer layers between them are where the actual "thinking" happens.
We now understand both ends of the pipeline: how tokens become vectors, and how vectors become tokens. But there's still something missing from our embedded vectors.
Detail 3: Positional Encodings
We now have rich semantic vectors for each token, but there's a problem: the model has no idea where each token appears. If you only use embeddings, "dog bites man" and "man bites dog" look nearly identical. Same vectors, different order. The embeddings tell the model what each token is, not where it is, even though word order completely changes meaning.
Positional encodings fix this. Before the embeddings enter the transformer layers, we inject position information into each vector. Different architectures do this differently: fixed sinusoidal patterns (original Transformer), learned position embeddings (GPT-2), or rotation-based encodings like RoPE (TinyLlama, Llama 2, most modern models).
The mechanics vary, but the core idea is the same: the embedding for "Write" at position 0 becomes slightly different from "Write" at position 50, because position has been baked into the vector.
TinyLlama uses RoPE (Rotary Position Embedding), which applies a rotation to the embedding based on position. The rotation angle depends on both the position index and the dimension index within the vector. This has a nice property: the dot product between two position-encoded vectors naturally encodes their relative distance.
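Here's a stripped-down sketch of that rotation. Real implementations work per attention head and differ in how they pair up dimensions, but the core is a position-dependent 2D rotation applied to pairs of coordinates:

```python
import numpy as np

def rope(x, position, base=10_000):
    """Rotate dimension pairs of vector x by position-dependent angles."""
    d = x.shape[-1]
    half = d // 2
    # One frequency per pair: early pairs rotate fast, later pairs slowly
    freqs = base ** (-np.arange(half) / half)
    angles = position * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]        # treat (x1[i], x2[i]) as a 2D point
    return np.concatenate([x1 * cos - x2 * sin,  # standard 2D rotation
                           x1 * sin + x2 * cos])

v = np.random.randn(64)                # one 64-dim head slice of a token vector
print(np.allclose(rope(v, 0), v))      # position 0: no rotation at all
print(np.dot(rope(v, 5), rope(v, 9)))  # dot product depends on the relative offset (9 - 5)
```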
After this step, the model sees a sequence of vectors that encode both meaning and position.
Detail 4: Autoregressive Generation
Here's a critical point about LLM inference: the model generates one token at a time. When you ask for a story and see "Once upon a time there was a dragon," that didn't appear in a single step. The model first chose "Once", then "upon", then "a", then "time." Each choice is conditioned on everything before it.
This is called autoregressive generation. The model is a next-token predictor. At each step, it takes the current sequence, runs it through the full pipeline we've discussed (embedding → processing → LM head → sample), gets one token, appends it, and repeats.
The process:
- Start with prompt tokens: [9566, 261, 4869] ("Write a story")
- Feed to model → get probability distribution over 32,000 tokens
- Sample one token: 12483 ("Once")
- Append: [9566, 261, 4869, 12483]
- Feed this longer sequence back → get next distribution
- Sample: 15234 ("upon") → append
- Continue until a stop token appears or max length reached
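Here's that loop as a minimal sketch with HuggingFace transformers, assuming the TinyLlama/TinyLlama-1.1B-Chat-v1.0 checkpoint. It drives the model by hand instead of calling model.generate, so every step is visible, and it uses greedy decoding to keep the code short:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.eval()

ids = tok("Write a story", return_tensors="pt").input_ids   # prompt token IDs

for _ in range(50):                                # generate up to 50 tokens
    with torch.no_grad():
        logits = model(ids).logits                 # shape [1, seq_len, 32000]
    next_id = logits[0, -1].argmax()               # greedy: most likely next token
    if next_id.item() == tok.eos_token_id:         # stop token ends generation
        break
    ids = torch.cat([ids, next_id.view(1, 1)], dim=1)   # append and repeat

print(tok.decode(ids[0], skip_special_tokens=True))
```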
Each iteration is a forward pass: the model takes the current token sequence, runs it through all 22 layers, produces logits via the LM head, and samples one token. On step N, the model sees N tokens: the original prompt plus the N−1 tokens already generated.
How does generation stop? The vocabulary includes special tokens like <EOS> (end of sequence) or </s>. When the model samples one of these, generation terminates. The model learned during training that these tokens signal "I'm done responding." Alternatively, you can set a maximum output length. Generation stops when that limit is reached, even if no EOS token appeared.
What's the context window? The context window (or context length) is the maximum number of tokens the model can "see" at once. TinyLlama's context window is 2,048 tokens. If your prompt plus generated output exceeds this, older tokens get truncated. The model literally forgets the beginning of the conversation. This is why long documents or extended chats eventually "lose" earlier context.
System implications: If your prompt is 100 tokens and you want a 200-token completion:
- You perform 200 forward passes
- Pass 1 processes 100 tokens
- Pass 2 processes 101 tokens
- ...
- Pass 200 processes 299 tokens
Generating a 200-token response requires 200 separate trips through the entire model. Each trip touches all 1.1 billion parameters. That's 220 billion parameter reads just to write a few sentences.
Look at the diagram above. Each row is one forward pass. The blue tokens appear on every single row, unchanged. On step 4, "Write", "a", and "story" are processed for the fourth time. On step 200, they'd be processed for the two-hundredth time. The only thing that's actually new on each row is the single green token at the end.
On step 150, compared with step 149, the model runs a full forward pass over 152 tokens. 151 of them are identical to the previous step. One token is new. Everything else is recomputation.
This points to a natural split. The prompt tokens are all known upfront and could be processed in a single parallel pass. After that, each new token depends on previous output, so generation is inherently sequential. These are two different workloads with very different performance characteristics.
Later we'll give them precise names ('prefill' and 'decode'). And we'll see the data structure that eliminates the redundancy too ('KV cache'). The idea is straightforward once you see it: compute something once for each token position, store it, and never recompute it. Prefill builds that stored state for the prompt. Decode extends it one token at a time.
But to understand what gets stored and why it doesn't change, we need to open the transformer block and look at the computation inside.
Detail 5: Self-Attention
In Detail 4 we saw the autoregressive loop: embed, process, project to logits, sample, repeat. The "Processing" step was a black box. Time to open it.
Inside that box is a stack of transformer blocks. TinyLlama has 22 of them, each structurally identical but with different learned weights. The embedded, position-encoded vectors enter block 1, flow through all 22 blocks in sequence, and exit into the LM head for next-token prediction.
Each block has two main components:
- Self-Attention: where tokens "communicate" with each other
- Feed-Forward Network (FFN): where each token is processed independently
Think of it this way:
- Attention asks: "Given what I know, what context from other tokens is relevant?"
- FFN asks: "Now that I have context, what should I transform this into?"
For the next few sections, we focus only on what's inside one block. The rest of the pipeline stays unchanged. We start with attention.
Why Attention?
We spoke about embeddings and how they give you the meaning of a word. I like to think of the embedding table as a dictionary where we look up a word (alphabetically in a real dictionary, by token ID in an LLM) and get back a "definition" in the form of a vector.
But what if a word means different things in different contexts? What if we come across the word "bat", which could be an animal or a piece of sports equipment? A dictionary handles this by listing multiple definitions across the most widely used contexts. You could say a dictionary gives you a generic sense of what the word may mean, covering the common cases.
When we actually use a word in a sentence, though, we need its meaning in that specific context, not some generic overview. And embeddings fall short here. The embedding for "bat" is the same vector whether the sentence is about sports or a cave. We need something that takes the generic dictionary definition and resolves it (the way a human reader does) into what the word means right here, in this sentence. If embeddings give us "definition," we need something that gives us "contextual definition."
That's what self-attention does. But before we get into its mechanics, let's establish what we'd want from such a mechanism:
- It must take in the generic embedding of each token.
- It must somehow look at the other tokens in the sequence and update its own representation based on context.
- It must do this efficiently.
Let's take an example. Consider this sentence: "The animal didn't cross the street because it was too tired." What does "it" refer to? To us, it's straightforward: "it" means the animal, not the street. But a model processing tokens left-to-right has no inherent way to know this. We need a mechanism that lets "it" look back at previous tokens and figure out which ones are relevant. That mechanism is attention.
Building Attention from First Principles
Rather than introducing the attention formula and then explaining what it does, let's derive it. We'll start from the most naive approach, see where it breaks, and fix it piece by piece until we arrive at the real thing.
Let's return to the "bat" problem. We said the embedding table gives "bat" a single, generic vector, which is a compromise that sits somewhere between the animal meaning and the sports equipment meaning. Now let's see how attention resolves it. Take the sentence "The bat flew." After embedding, we have three vectors:
embed("The") = [0.04, -0.11, 0.52, ..., 0.03] (2,048 numbers)
embed("bat") = [0.61, 0.45, -0.33, ..., 0.27]
embed("flew") = [0.58, 0.39, 0.21, ..., -0.14]
Each vector encodes that token's meaning in isolation. "bat" knows it could mean an animal or sports equipment, but it doesn't know which one applies here. "flew" knows it means movement through the air, but it doesn't know what's flying. We need a way for each token to look at the others and "pull" in relevant context.
Step 1: Measuring Similarity
We have vectors that encode meaning, and we know from the embedding section that similar meanings produce similar vectors. So the first question is: given two embedding vectors, how do we express how similar they are as a single number?
Think about what "similar" means for vectors. Two vectors pointing in the same direction represent similar meanings. Two vectors pointing in perpendicular directions are unrelated. Two vectors pointing in opposite directions are dissimilar. The angle between them captures this perfectly.
The cosine of that angle gives us a clean similarity score: 1 means same direction (identical meaning), 0 means perpendicular (unrelated), −1 means opposite (dissimilar). This is called cosine similarity.
How do you actually compute the cosine of the angle between two vectors? The dot product of two vectors, divided by their magnitudes, gives exactly the cosine of the angle between them. In practice, we often skip the magnitude normalization and use the raw dot product directly (likely because it's simpler, and the magnitudes carry useful information too).
But, you may wonder, what about our third criterion from above? How is this "efficient"? Computing the dot product between every pair of tokens is just a matrix multiplication. If X is a matrix where each row is a token embedding, then X × Xᵀ gives us all pairwise similarity scores in one operation.
Let's focus on "bat." We compute its dot product with every token in the sequence:
# How much is the word "bat" related to each of the words in "The bat flew"?
# How related is "bat" with "The"?
score("bat", "The") = embed("bat") · embed("The") = 0.3
# How related is "bat" with itself?
score("bat", "bat") = embed("bat") · embed("bat") = 5.8
# How related is "bat" with "flew"?
score("bat", "flew") = embed("bat") · embed("flew") = 6.1
These scores tell us: "bat" is barely related to "The" (score 0.3), quite similar to itself (score 5.8), and most similar to "flew" (score 6.1), both relate to animals and motion. To turn these into proper weights, we normalize with softmax (same function from the LM head section) so they sum to 1:
weight("bat", "The") = 0.02
weight("bat", "bat") = 0.43
weight("bat", "flew") = 0.55
This is like reading all the definitions for a word in a dictionary, glancing at the sentence it appears in, and collapsing them into the one meaning that fits. We are doing exactly that: compute a new, context-aware definition of "bat" as a weighted sum of all the embeddings:
new_bat = 0.02 × embed("The") + 0.43 × embed("bat") + 0.55 × embed("flew")
This new_bat vector now carries information from the other tokens, weighted by relevance. "flew" contributed the most. It's the word that 'disambiguates' the word "bat" toward its animal meaning. "The" is a function word with almost no semantic content, so it contributed nearly nothing.
Notice what just happened. "bat" started with a generic embedding, a compromise vector sitting somewhere between the animal meaning and the sports equipment meaning. After this weighted sum, "flew" pulled the vector toward the animal cluster. The result is no longer a generic dictionary definition. It's a contextual definition: "bat, the flying animal."
In this example, raw dot products did the job. The disambiguation we needed aligned with semantic similarity: "flew" is semantically closer to the animal meaning of "bat" than to the sports equipment meaning, so the dot product captured exactly the right signal.
(In reality, these embeddings live in 2,048 dimensions, not the 2D picture you might be imagining. The shift happens across many dimensions simultaneously. But the intuition of "movement toward the right meaning" holds.)
We do this for every token position. Each token computes similarity scores against all others, normalizes them, and takes a weighted sum. After this operation, every token's vector has been enriched with context from the rest of the sequence.
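Here's that naive version in NumPy: raw dot products, softmax, weighted sum. The embeddings are made-up 4-dimensional toys so the shapes stay readable:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Toy embeddings for "The bat flew" (4 dims instead of 2,048)
X = np.array([[0.1, 0.0, 0.2, 0.1],    # "The"
              [1.2, 0.9, -0.4, 0.6],   # "bat"
              [1.1, 0.8, 0.5, -0.2]])  # "flew"

scores = X @ X.T              # all pairwise dot products: shape (3, 3)
weights = softmax(scores)     # each row sums to 1
contextual = weights @ X      # each token becomes a weighted blend of all embeddings

print(weights[1].round(2))    # how much "bat" attends to "The", "bat", "flew"
print(contextual[1])          # the new, context-aware "bat" vector
```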
This is a perfectly valid form of attention. It captures the core idea: each token looks at every other token, measures relevance, and pulls in information proportionally.
However, while it does work to some degree, it has a fundamental limitation.
Step 2: Why Raw Dot Products Aren't Enough
The dot product of raw embeddings measures semantic similarity, meaning how often these words appear in similar contexts. But attention often needs to capture relationships that have nothing to do with semantic similarity.
Go back to: "The animal didn't cross the street because it was too tired." The word "it" needs to attend heavily to "animal." But "it" and "animal" are semantically very different words. "It" is a pronoun. "Animal" is a concrete noun. Their raw embeddings won't be similar. They appear in completely different linguistic contexts during training.
What we actually want is for "it" to ask "who am I referring to?" and for "animal" to signal "I'm a noun that can be referred to." That's not a question about semantic similarity. It's a question about grammatical and contextual role, the kind of relationship these words have in this sentence.
Raw dot products can't learn this. They're locked into whatever similarity the embedding space already encodes. We need the model to learn what to compare: to transform the vectors before comparing them, so that the comparison exposes the right kind of relationship for the task at hand.
Step 3: Queries, Keys, and Values
Imagine you're hiring an ML infrastructure engineer. Three candidates sit across from you: an infra engineer with five years at a big tech company, a strong generalist software engineer, and an ML researcher fresh out of a PhD.
The outcome of your evaluation is a weighted blend of what each candidate brings:
hiring_outcome = 0.7 × infra_engineer + 0.2 × generalist + 0.1 × researcher
The infra engineer is the strongest match for your need, so they contribute the most. The generalist brings something useful but less targeted. The researcher contributes least to this particular role.
Focus on one candidate. We need two things: how relevant the infra engineer is to your need, and what they actually contribute if hired.
hiring_outcome = (how_relevant_infra_engineer_is) × what_infra_engineer_actually_brings
Simplify:
hiring_outcome = (weight) × value_of_infra_engineer
This equation has two unknowns: the weight and the value. Your instinct might be that they're closely related. Both involve the infra engineer after all! But they serve fundamentally different purposes.
Look at the hiring process closely. Three distinct things are at play:
(a) Your job spec: "someone who combines ML knowledge with infrastructure experience." Structured for matching against candidates.
(b) Each candidate's resume: the infra engineer highlights cloud deployments, the generalist highlights system design, the researcher highlights ML publications. Optimized for comparison against job specs, not a description of what they'd actually do on the job.
(c) Each candidate's actual contribution if hired: the infra engineer writes the modules, debugs the 3 AM outages, and knows which Kubernetes configs silently break under load. The generalist spots architectural bottlenecks early and writes code that the rest of the team can actually maintain. The researcher catches a flawed training setup before it burns a week of GPU time.
Each candidate is one person, but the process needs three distinct representations: one of your need ("I'm looking for these things"), one of each candidate for matching ("here is what I have that may help you"), and another of the candidate for the actual work ("this is how I actually contribute"). Now we can label our equation properly:
hiring_outcome = match(your_job_spec, candidate_resume) × candidate_actual_contribution
= weight × value_of_infra_engineer
- weight comes from comparing your job spec (a) against the candidate's resume (b)
- value_of_infra_engineer is their actual contribution (c)
The resume gets them matched; the actual work is what they deliver. Different things, computed separately.
Same thing with tokens. Each token starts as one embedding vector, but attention needs three forms of it:
- What it's looking for from context (like a job spec). What kind of surrounding information would help disambiguate this token? This is the Query.
- How it describes itself for matching (like a resume). When other tokens are looking around for context, what should they see when they look at this token? This is the Key.
- What it actually contributes when selected (like the real work a hired person does). The content it shares when pulled into the weighted sum. This is the Value.
Why three forms, not one? Same reasoning as the analogy. If a single vector had to serve as both the matching criteria and the actual contribution, those two roles would be locked together. The resume is optimized for matching; the actual work is what matters after the match. Separating them gives the model flexibility to match on one set of features and contribute a different set. The comparison is also asymmetric: your spec matched against Candidate A's resume is a different question than Candidate A's spec matched against yours. Two separate transformations for Query and Key capture directional relationships.
Each form is produced by multiplying the raw embedding by a learned weight matrix:
Query = input × W_Q
Key = input × W_K
Value = input × W_V
Three separate matrices (W_Q, W_K, W_V), all learned during training. The model discovers what aspects of each embedding to expose for each role.
Putting it all together for a given token:
- Compute this token's Query. Compute every token's Key and Value.
- Dot product of this token's Query against each Key to get relevance scores.
- Softmax to normalize scores into weights summing to 1.
- Weighted sum of Values to produce a new, context-aware embedding.
Three learned projections give the model independent control over what to look for (Query), what to advertise for matching (Key), and what to share when selected (Value). We started from raw dot products and arrived here by fixing limitations one at a time. The result is a mechanism that resolves ambiguity using surrounding context, elevating a generic, dictionary-esque meaning to a contextual one.
The Attention Formula
Let's express what we just derived as math, one piece at a time.
Start with the matching step. The hiring manager compares her job spec (Query) against each candidate's resume (Key). That comparison is a dot product. For all tokens at once, it's a matrix multiplication, Q × Kᵀ: each Query dotted against every Key.
Each entry in this table tells us how well one token's Query matches another token's Key. A large number means strong match, a small number means weak match. That's the matching step, expressed as a single matrix multiply.
But knowing who matches well isn't enough. We still need to collect what those matched tokens actually contribute. That's what the Values are for. The hiring manager doesn't just rank candidates by fit and stop there. She hires them, and they do work. Similarly, after computing how relevant each token is, we need to gather the content those tokens offer. We do that by multiplying the match scores by V:
Each Value gets scaled by how strong its match was. Tokens that matched well contribute a lot. Tokens that matched poorly contribute almost nothing.
There's a problem with this, though. Think about what happens when we add up these contributions. One match produces a number like 50, another produces 2. The 50 dominates everything else. The output is at the mercy of whatever raw magnitudes the dot products happened to produce.
What we want is proportions. Each token should contribute its share of the total relevance, and those shares should add up to 1. There is a mathematical operation that does exactly this: softmax, the same function we used at the LM head. It takes a row of numbers and converts them into proportions summing to 1, preserving the relative ordering.
This works. But there's a lurking issue.
In the "bat" example, our dot products were small: 0.3, 5.8, 6.1. We only had a handful of dimensions. TinyLlama uses 64 dimensions per head. A dot product sums one term per dimension, so 64 dimensions means summing 64 terms. The results are naturally much larger.
When the numbers going into softmax are large, softmax becomes extreme. It puts nearly all the weight on the single largest value and nearly zero on everything else. Attention stops blending information from multiple tokens and collapses into picking just one.
We need to shrink the dot products before they hit softmax. Divide by something:
What should we divide by? A fixed constant like 100 would work for one model but break for another with different dimensions. We need something that adapts.
The dot products got large because we're summing more terms. With 3 dimensions, we sum 3 terms. With 64, we sum 64. The number of dimensions (d_k) drives the growth. So the scaling factor should depend on d_k.
But dividing by d_k directly overshoots. It shrinks things more than necessary.
To find the right amount, think about what each term in the dot product does. Each one adds a little randomness to the total. With 3 terms, the total doesn't wander far from zero. With 64, it wanders further. How far? In statistics, this is measured by variance: how much a value typically deviates from its average. For a sum of independent terms, each term adds its own variance to the total. So with d_k terms, the total variance is about d_k times larger.
But variance measures spread in squared units. Think of it like area vs. side length. A square with area 64 has sides of length 8, not 64. Same idea: if the variance grew by a factor of d_k, the actual dot product values only grew by √d_k. That's what we divide by.
That gives the full attention formula:
Attention(Q, K, V) = softmax(Q × Kᵀ / √d_k) × V
Every piece maps to a step we already derived: Q × Kᵀ is the Query-Key matching, softmax turns match strengths into proportions, multiplying by V collects the actual contributions, and dividing by √d_k keeps the math stable.
In the single-head picture we've built so far (using TinyLlama's 2,048-dimensional hidden size), Q, K, and V are each computed by multiplying the input by a learned weight matrix:
Q = input × W_Q (2,048 → 2,048)
K = input × W_K (2,048 → 2,048)
V = input × W_V (2,048 → 2,048)
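If you want to see the whole derivation as code, here's a single-head sketch in NumPy. The weight matrices are random stand-ins for the learned W_Q, W_K, W_V, and there's no causal masking; it's just the formula above, nothing more:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq_len, d_model = 3, 2_048                       # "The bat flew"
X = np.random.randn(seq_len, d_model)             # embedded, position-encoded tokens

# Learned projection matrices (random stand-ins here)
W_Q = np.random.randn(d_model, d_model) * 0.02
W_K = np.random.randn(d_model, d_model) * 0.02
W_V = np.random.randn(d_model, d_model) * 0.02

Q, K, V = X @ W_Q, X @ W_K, X @ W_V               # three views of the same tokens

scores = Q @ K.T / np.sqrt(d_model)               # Query-Key matching, scaled
weights = softmax(scores)                         # each row: proportions summing to 1
output = weights @ V                              # weighted sum of Values

print(output.shape)                               # (3, 2048): one new vector per token
```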
So far we've described attention with a single set of Q, K, V projections. That works, but it's limiting: a single attention computation can only capture one type of relationship per layer. What if the model needs to simultaneously track syntax, semantics, and position? Next, we'll see how splitting attention into multiple independent "heads" solves this.
Detail 6: Multi-Head Attention
We built attention with one set of learned projections: one , one , one . Each token gets one Query, matches it against every Key, and collects a single weighted blend of Values. One question, one answer, one blended output vector.
Where does this break?
Extend the "bat" sentence: "The bat flew out of the cave and landed on a branch." Focus on the word "landed." To process this word well, the model needs several pieces of context at once:
- Who landed? It should attend to "bat" (subject tracking).
- How did it get here? It should attend to "flew" (action sequence).
- Where from? It should attend to "cave" (spatial origin).
These are three genuinely different questions. But with one set of attention weights, "landed" produces one set of numbers that must sum to 1. If it puts weight 0.4 on "bat", 0.3 on "flew", and 0.2 on "cave", those signals all get averaged into a single output vector. The model can't cleanly separate "the subject is bat" from "this follows flying" from "the origin is cave." It gets one blurry compromise.
With short sentences, this averaging might be good enough. But real sequences have dozens or hundreds of tokens, and the model often needs to simultaneously track syntax, coreference, temporal structure, locality, and patterns we can't even name. Cramming all of that into one set of weights is asking a single number line (summing to 1) to encode multiple independent relationships. Something has to give.
The Fix
What if we just ran attention multiple times in parallel?
Same input tokens. But each copy gets its own , , , so each copy can ask a different question. One copy might learn to track subjects. Another might learn to track action sequences. Another might focus on nearby tokens for local structure. Each produces its own output, and we combine them at the end.
Return to the hiring analogy from Detail 5. There, one hiring manager evaluated candidates against one rubric. That works for one role. But building a team means evaluating candidates along multiple dimensions simultaneously. So you form an interview panel: one interviewer focuses on infrastructure depth, another on communication skills, another on ML intuition. Same candidate pool, different rubrics, separate evaluations. The final decision synthesizes all their assessments.
Each of these independent attention computations is one interviewer on the panel. The standard term for each one is a head. Running multiple heads in parallel is multi-head attention.
How Heads Fit Inside the Vector
You might worry about cost. Running 32 separate full-size attention computations would mean 32× the work. The trick is to not run them at full size.
TinyLlama's hidden dimension is 2,048. It uses 32 heads. Rather than giving each head the full 2,048 dimensions, the model divides the space into 32 slices of 64 dimensions each:
d_model = 2,048
n_heads = 32
d_head = d_model / n_heads = 64
In practice, you still compute Q, K, V at full width once:
Q = input × W_Q (2,048 → 2,048)
K = input × W_K (2,048 → 2,048)
V = input × W_V (2,048 → 2,048)
So we have Q, K, V, each a 2,048-dimensional vector per token. Now slice each of them into 32 pieces of 64 dimensions. Piece 1 gets dimensions 1 through 64. Piece 2 gets dimensions 65 through 128. And so on. Each piece is one head's workspace.
Now run the full attention computation (the Q × Kᵀ, softmax, multiply-by-V formula from Detail 5) independently on each 64-dimensional piece. Head 1 runs attention using only its slice of Q, K, V. Head 2 does the same with its slice. All 32 heads operate in parallel, each on its own 64-dimensional subspace. This is also where the scaling becomes concrete: each head operates in 64 dimensions, so d_k = 64 and the scale factor is √64 = 8.
The total compute is roughly the same as one full-size attention operation. We haven't added work; we've reorganized it into 32 parallel smaller pieces.
Each head produces a 64-dimensional output per token. To get back to the model's full width, concatenate all 32 outputs end-to-end: 32 × 64 = 2,048. Then one final learned projection (W_O) "mixes" information across heads, letting the model combine what different heads discovered into a single output vector.
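Here's how that slice, attend-per-head, concatenate, mix sequence looks as a NumPy sketch (random stand-in weights, toy sequence of 3 tokens):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq_len, d_model, n_heads = 3, 2_048, 32
d_head = d_model // n_heads                            # 64 dims per head

X = np.random.randn(seq_len, d_model)
W_Q, W_K, W_V, W_O = (np.random.randn(d_model, d_model) * 0.02 for _ in range(4))

# Full-width projections, then reshape into 32 slices of 64 dims each
def split_heads(M):
    return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)  # (heads, seq, d_head)

Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

# Each head runs the same attention formula independently on its 64-dim slice
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)    # (heads, seq, seq)
weights = softmax(scores)
per_head = weights @ V                                 # (heads, seq, d_head)

# Concatenate head outputs back to 2,048 dims, then mix with W_O
combined = per_head.transpose(1, 0, 2).reshape(seq_len, d_model)
output = combined @ W_O
print(output.shape)                                    # (3, 2048)
```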
What Heads Learn
Because each head has its own projections, different heads can specialize. Researchers have observed:
- Local heads that mostly attend to the previous few tokens
- Syntactic heads that latch onto subject-verb pairs
- Delimiter heads that jump to quotes, parentheses, or newlines
- "Sink" heads that attend heavily to the first token regardless of content
Nothing forces a head into any particular role. Some heads end up redundant. The point is capacity: the layer is no longer limited to a single attention pattern. Going back to our "landed" example, one head can put most of its weight on "bat" while another head independently puts its weight on "flew" and a third on "cave." Three questions, three separate answers, combined at the end.
There is a cost to this. Every head maintains its own Keys and Values, and (as we'll see when we get to the KV cache) we store those for every past token. 32 heads means 32 sets of K and V per layer per position. The next section covers variants that reduce this cost.
Detail 7: Attention Variants (GQA, SWA, MLA)
Multi-head attention gave us better representations at the cost of storing more state. That state adds up fast.
The Storage Problem
At the end of Detail 6, we noted that every head maintains its own Keys and Values. Let's think about what that means during generation.
Recall from Detail 4: the model generates one token at a time. To produce token t, attention needs to compare the new token's Query against the Keys of every previous token 0..t-1, and then take a weighted sum of their Values. That's what attention does: match against Keys, collect from Values.
Now, do we recompute every previous token's K and V from scratch on every step? We could, but it would be wasteful. Those tokens haven't changed. Token 3's Key is the same whether we're generating token 50 or token 500. So inference engines just store them. Each time we process a new token, we compute its K and V, add them to the stored collection, and reuse the rest. This stored collection of past Keys and Values is called the KV cache. (We'll formalize it properly in Detail 10. For now, just know it exists and grows by one entry per token generated.)
Let's see where multi-head attention gets expensive. With 32 heads, each maintaining its own K and V, we're storing 32 separate K vectors and 32 separate V vectors for every token, at every layer.
Let's put numbers to this. If TinyLlama used standard multi-head attention (it doesn't, and you'll soon see why):
Per position, per layer:
32 heads × 64 dims × 2 (one K, one V) = 4,096 values
Across all 22 layers, at FP16 (2 bytes each):
4,096 × 22 × 2 bytes = 180 KB per token position
At max context (2,048 tokens):
2,048 × 180 KB ≈ 360 MB per sequence
360 MB just to remember context for a single sequence. And this is TinyLlama, a 1.1B model. For a 70B model, the KV cache can reach gigabytes per sequence. Serve 50 users at once and the cache alone could exhaust your GPU's memory.
The storage cost scales directly with the number of heads. 32 heads means 32 copies of K and V at every position and every layer. The question is whether we actually need all 32.
Do We Really Need 32 Copies?
Let's go back to the three roles from Detail 5. Each token gets transformed into:
- Q (Query): what this head is looking for
- K (Key): how this token advertises itself for matching
- V (Value): what this token contributes when selected
In the "landed" example from Detail 6, three heads asked three different questions: "who is the subject?", "what action preceded this?", "where did it come from?" Those are different Queries. Each head needs its own Q to ask its own question.
But think about what K and V represent. They describe the context tokens, the sentence the heads are searching over. And it's the same sentence regardless of what question you're asking. One head looks for the subject, another for the verb, another for the location. Different questions, same pool of tokens to search.
Imagine 32 analysts, each researching a different question about the same company. Do they each need their own copy of the company's financial filings? No. They can all search the same documents. What matters is that each analyst has their own research question. The documents don't change depending on who's reading them.
Q and K/V serve different roles, and they don't have to scale together. You can have many Query heads (many different questions) while sharing fewer sets of Keys and Values (fewer copies of the documents to search).
The Extreme: One Shared K/V
The most aggressive version of this idea: all 32 query heads share a single Key and a single Value. Each head still has its own , so each head can still attend to different tokens. But they all look into the same K/V.
This is called Multi-Query Attention (MQA).
MQA for TinyLlama:
1 KV set × 64 dims × 2 (one K, one V) = 128 values per position per layer
128 × 22 layers × 2 bytes = 5.6 KB per token position
At 2,048 tokens: ~11.5 MB per sequence
Compared to 360 MB with standard multi-head attention, that's a 32× reduction. The savings come from a simple accounting change: instead of 32 K/V sets, we store 1.
But there's a cost. All heads now search the same representation of context. Think back to the analyst analogy. With 32 separate copies, each analyst could have their filings organized differently: one sorted by date, another by department, another by transaction size. Each organization (each K/V representation) is optimized for a different kind of question. With MQA, all 32 analysts share one filing system. If one analyst's question would benefit from a different organization, too bad.
In practice, MQA can lose quality. For some models and tasks it works fine. For others, collapsing 32 specialized K/V representations into one generic one hurts. The question becomes: is there a middle ground?
The Middle Ground: Grouped-Query Attention (GQA)
Instead of jumping from 32 K/V sets all the way down to 1, what if we used a small number somewhere in between?
Grouped-Query Attention does exactly this. Keep a handful of K/V sets, fewer than 32 but more than 1, and divide the query heads into groups. Each group shares K/V.
TinyLlama uses GQA with 4 KV groups:
- 32 query heads ÷ 4 KV groups = 8 query heads per group
- Within each group, 8 heads share the same K and V
- Across groups, the K/V representations are different
GQA-4 for TinyLlama:
4 groups × 64 dims × 2 (K+V) = 512 values per position per layer
512 × 22 layers × 2 bytes = 22 KB per token position
At 2,048 tokens: ~45 MB per sequence
8× smaller than standard MHA. Not as extreme as MQA's 32× reduction, but 4 K/V groups can still specialize. One group might learn a K/V representation tuned for syntactic relationships, another for semantic ones, another for positional patterns. You lose some diversity compared to 32 separate sets, but you keep the most important distinctions.
Most modern models land here. Llama 2 70B, Llama 3, Gemma, and TinyLlama all use GQA.
All three of these (standard MHA, GQA, MQA) are really points on the same spectrum, controlled by a single number: n_kv_heads, the count of distinct K/V sets.
| Variant | n_kv_heads | Cache per position (TinyLlama) |
|---|---|---|
| MHA | 32 (= n_heads) | 180 KB |
| GQA-4 | 4 | 22 KB |
| MQA | 1 | 5.6 KB |
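The arithmetic behind that table fits in a few lines. A small calculator, assuming TinyLlama's shape (22 layers, 64-dim heads, FP16):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_value=2):
    # The trailing 2× accounts for storing both a Key and a Value at every position
    return n_layers * seq_len * n_kv_heads * d_head * 2 * bytes_per_value

layers, d_head, ctx = 22, 64, 2_048
for name, n_kv in [("MHA", 32), ("GQA-4", 4), ("MQA", 1)]:
    per_token = kv_cache_bytes(layers, n_kv, d_head, 1) / 1e3
    per_seq = kv_cache_bytes(layers, n_kv, d_head, ctx) / 1e6
    print(f"{name}: {per_token:.1f} KB per token, {per_seq:.0f} MB per sequence")
```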
MHA is one end. MQA is the other. GQA is everything in between. This is an architecture choice baked in during training. By the time you're serving the model, you're living with whatever was chosen.
A Different Lever: How Far Back to Look
GQA/MQA reduce how much you store per token position. But there's another question entirely: how many past positions do you need to store at all?
Standard attention lets every token attend to all previous tokens. Token 5,000 can look all the way back to token 1. But does it actually need to?
Think about how you read. Right now, understanding this sentence mostly depends on the last few sentences. You need the context of this section, the current paragraph, maybe the heading. You don't need to re-read the opening of the blog post to parse this sentence. Occasionally you do look far back (a term defined much earlier, a running example established at the start), but for most of what you read, recent context dominates.
Sliding Window Attention (SWA) formalizes this observation. Instead of attending to all previous tokens, each token only attends to the most recent W tokens (the "window"), itself included. Everything older is masked out, as if it doesn't exist.
If W = 4,096 and the sequence has reached 10,000 tokens, token 5,000 can attend to tokens 905 through 5,000. Token 1 is invisible to it.
The trade-off is direct:
- Smaller window = less memory. The KV cache only needs to hold the last W tokens' worth of K/V, regardless of how long the sequence gets. Memory stays bounded.
- Smaller window = less long-range access. If critical context lives 5,000 tokens back and your window is 4,096, you can't directly reach it.
But there's a subtlety. Information can still propagate beyond the window size by flowing through intermediate layers. Token 5,000 can't attend to token 1 directly, but if token 3,000 attended to token 1 in an earlier layer, and token 5,000 attends to token 3,000, the information has traveled indirectly. It's lossy and imprecise, but not completely lost.
In practice, many models use SWA in some layers and full attention in others. Gemma 3, for example, uses a 5:1 ratio: five sliding-window layers for every one full-attention layer. The sliding-window layers handle local relationships cheaply. The occasional full-attention layer provides a direct path for long-range information. This hybrid approach gets most of the memory savings while keeping long-range capability.
Notice that this is a fundamentally different lever than GQA. GQA reduces the size of what you store per token position (fewer K/V sets). SWA reduces the number of token positions you store. They're orthogonal, and a model can use both.
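If it helps to see the masking concretely, here's a tiny sketch that builds a sliding-window attention mask (W = 4 is a made-up window, just to keep the printout small):

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    # mask[i, j] is True if token i is allowed to attend to token j
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)   # causal AND within the last `window` tokens

print(sliding_window_mask(8, 4).astype(int))
# Row 7 has 1s only at columns 4..7: token 7 can no longer see tokens 0..3.
```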
Yet Another Lever: Compress What You Store
GQA shares K/V across heads: fewer sets, same size each. SWA limits how many past tokens you keep: fewer positions, same K/V per position. There's a third option: keep all the positions and all the heads, but make each K/V entry smaller.
Multi-Head Latent Attention (MLA), used in DeepSeek V2 and V3, takes this approach. Instead of caching the full K and V vectors for each token, it compresses them into a smaller "latent" vector. When attention needs the K and V for a past token, it decompresses them on the fly from the stored latent.
Think of the difference between storing full-resolution photos and storing compressed thumbnails. The thumbnails take far less space in the filing cabinet. When you actually need to examine a photo, you reconstruct the full resolution from the compressed version plus some stored metadata. The reconstruction costs some compute, but the storage savings can be dramatic.
Concretely, if the latent dimension is half the original K/V dimension, the cache shrinks by roughly half; at a quarter of the original, it shrinks by roughly 75%. DeepSeek V2 pushes this much further: its shared latent is 512-dimensional, a small fraction of the full per-head K/V it replaces, and the reported KV cache savings are well over 90% relative to standard MHA.
The interesting part is what happens to quality. GQA saves memory by making heads share identical K/V vectors. That's a hard constraint: heads that would benefit from seeing different K/V representations are forced to see the same one. MLA does something different. Each head still gets its own K and V after decompression. The shared latent is a compressed joint representation, not a one-size-fits-all compromise. Each head applies its own learned decompression matrix to the latent, so the K/V it sees is head-specific. You store one small latent per position, but reconstruct head-specific K/V from it.
The result is that MLA comfortably beats GQA on quality in DeepSeek V2's ablations, despite using comparable cache memory. That makes sense: GQA forces heads to share identical K/V, while MLA reconstructs head-specific K/V from the latent. The per-head specialization that GQA sacrifices is exactly what MLA preserves. Against standard MHA (where each head already has its own full, uncompressed K/V), MLA won't have an inherent quality advantage since you can't beat uncompressed by compressing. But it gets close, at a fraction of the memory cost. The trade-off is compute: decompression adds a matrix multiply per layer per token during attention. Whether that's worth it depends on whether you're more constrained by memory or by compute, and during decode (don't worry if you don't know what decode is, we will cover that in a later section), you're almost always more constrained by memory.
The trade-off is the mirror of GQA's:
- GQA saves memory by sharing K/V across heads (fewer copies, full size)
- MLA saves memory by compressing K/V per position (all copies, smaller size)
- GQA sacrifices some per-head specialization; MLA preserves it but spends extra compute on decompression
The Three Levers
Every attention variant we've seen pulls one of three levers to reduce KV cache memory:
KV cache memory ∝ n_layers × seq_len × n_kv_heads × d_head × 2 (K,V)
| Lever | What it targets | Examples |
|---|---|---|
| Share K/V across heads | ↓ n_kv_heads | GQA, MQA |
| Limit attention range | ↓ effective seq_len | SWA |
| Compress K/V entries | ↓ effective d_head | MLA |
These levers are orthogonal. A model can combine them: GQA to share K/V across heads, SWA on most layers to cap context length, and full attention on a few layers for long-range access. Real models do exactly this.
The payoff is concrete. For a Llama 2 70B-like configuration at 2,048 tokens:
- Standard MHA: ~5.4 GB per sequence
- With GQA-8 (what it actually uses): ~0.67 GB per sequence
That 8× reduction is the difference between serving 1 user and serving 8 on the same GPU. These are architecture choices baked in at training time, but understanding them explains why some models are cheap to serve and others are expensive.
We've now covered the attention side of the transformer block and the levers that control its memory cost. But attention is only half of each block. Each block also has a simpler component that processes tokens individually.
Detail 8: Feed-Forward Network (FFN)
After attention, each token's vector has been enriched with context from the rest of the sequence. "bat" is no longer the generic dictionary entry; it's been pulled toward "flying animal" by attending to "flew." But the vector is still just a blend. A weighted average of the raw inputs.
Think about what attention actually did. It gathered relevant information. "bat" looked around, found "flew" was relevant, and mixed some of "flew" into its own representation. That's useful, but gathering isn't the same as processing. Attention answered "what context is relevant?" It didn't answer "given this context, what should I conclude?"
Go back to the hiring analogy. Attention was the panel of interviewers gathering signal from candidates. Each interviewer asked their question and collected weighted impressions. But nobody has made a decision yet. Nobody has synthesized "strong infra skills + weak communication + deep ML intuition" into "this person would be great for the platform team but not the customer-facing role." The raw signal has been collected. Now someone needs to sit down and think about what it means.
That's what the FFN does. It's the per-token processing step that follows attention. Where attention mixes information across tokens (communication), the FFN processes information within each token (computation). It takes the blended vector that attention produced and transforms it through a small neural network, independently at each position.
The distinction matters. Attention is wide but shallow: it can pull from anywhere in the sequence, but all it does is compute weighted averages. Linear combinations. It can't, on its own, compute anything nonlinear. It can't decide "this combination of features means X." The FFN is where nonlinearity lives. It's where the model applies learned functions to the gathered context: combining features, suppressing irrelevant ones, amplifying important patterns, and transforming the representation into something the next layer can build on.
A useful way to think about it: attention diffuses information across the sequence. The FFN is where the model pauses and actually thinks about what attention just collected.
This happens independently at every token position. No token-to-token connections. Same weights, applied to every row of the [seq_len × d_model] matrix. Each token gets the same neural network applied to its own (now context-enriched) vector.
If you want the MLP story from scratch, I wrote a separate deep dive: Multi-Layer Perceptrons: How Neural Networks Bend Space to See. Here we only need the transformer version: a position-wise MLP reused at every token position.
What It Looks Like
Structurally, the FFN is simple. Three matrix multiplications with a nonlinearity in the middle:
- Expand: Project the 2,048-dimensional vector up to a larger space (5,632 dimensions in TinyLlama)
- Transform: Apply a nonlinear function. This is the step that lets the network compute things that weighted averages can't.
- Contract: Project back down to 2,048 dimensions
Why expand first? For the same reason you might spread out parts on a large workbench before assembling them. In the larger space, features that overlap in 2,048 dimensions get separated. The network can manipulate them independently, then compress the result back down. More room in the intermediate step means the network can represent more complex functions of its input.
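TinyLlama, like other Llama-family models, uses a gated version of this expand-transform-contract pattern: two expansion matrices whose outputs are multiplied together (one passed through a SiLU nonlinearity), then a contraction. A sketch with random stand-in weights:

```python
import numpy as np

def silu(x):                      # the nonlinearity (SiLU / "swish"): x * sigmoid(x)
    return x / (1 + np.exp(-x))

d_model, d_ff = 2_048, 5_632      # TinyLlama's hidden and intermediate sizes

# Random stand-ins for the three learned matrices
W_gate = np.random.randn(d_model, d_ff) * 0.02
W_up   = np.random.randn(d_model, d_ff) * 0.02
W_down = np.random.randn(d_ff, d_model) * 0.02

def ffn(x):                                   # x: one token's 2,048-dim vector
    expanded = silu(x @ W_gate) * (x @ W_up)  # expand to 5,632 dims and transform
    return expanded @ W_down                  # contract back to 2,048 dims

token_vector = np.random.randn(d_model)
print(ffn(token_vector).shape)                # (2048,)
```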
Where the Parameters Live
The FFN is where most of the model's weight lives. For each of TinyLlama's 22 layers:
FFN parameters per layer:
Three matrices of roughly 2,048 × 5,632 each: 3 × 11.5M ≈ 34.6M parameters
Across all 22 layers: 34.6M × 22 ≈ 760M parameters
That's roughly 2/3 of TinyLlama's total 1.1B parameters. Attention gets most of the conceptual spotlight, but the FFN is where most of the model's learned knowledge actually lives. The attention mechanism is comparatively parameter-light; its job is routing. The FFN is where the model stores and applies what it learned during training.
Now we know both pieces inside a transformer block. How do they wire together?
Detail 9: The Complete Transformer Block
In Detail 5 we opened the "Processing" black box and found a stack of transformer blocks. Then we zoomed into one block and unpacked its two components: self-attention (Details 5, 6, 7) and the FFN (Detail 8).
Here's that block as it stood before the detour, still a black box:
We now know what's inside. The block has two components: self-attention (where tokens communicate) and the FFN (where each token is processed independently). Data flows top to bottom: attention first, then FFN.
But wiring them together is not just "run attention, then run FFN". Two practical problems show up when you stack these blocks 22 layers deep.
Problem 1: Every Layer Becomes a Rewrite
Each sublayer (attention or FFN) takes in a vector and produces a new one. If each sublayer simply replaces its input with its output, the first layer's transformation gets overwritten by the second, which gets overwritten by the third. By layer 22, any nuance from the early layers is gone.
The model also can't make small adjustments. If layer 15 needs a subtle refinement, it has to produce an entirely new 2,048-dimensional vector and hope the next 7 layers don't destroy what it just did.
The fix: instead of replacing the input, add to it.
output = input + sublayer(input)
The sublayer only needs to learn the change from the input, not an entirely new representation. The original signal passes through untouched, with each layer contributing a small delta. After 22 layers, the final vector is the original input plus 22 learned adjustments. Nothing gets overwritten. Early-layer features survive to the end.
This is called a residual connection (or skip connection). It also helps during training: gradients can flow directly through the addition, bypassing the sublayer entirely if needed, which prevents the vanishing gradient problem in deep networks.
Problem 2: Values Drift Out of Range
Each residual connection adds to the running vector. After several additions, values can grow large. Or small adjustments might push some dimensions toward extreme values while others stay near zero. The next sublayer expects inputs in a reasonable range. If one dimension is sitting at 500 and another at 0.01, the matrix multiplications inside that sublayer will behave poorly.
We need to recenter and rescale before each sublayer. Take the vector, compute its mean and variance across dimensions, subtract the mean, divide by the standard deviation. The result has zero mean and unit variance. Every sublayer sees well-behaved inputs regardless of what accumulated before it.
This is layer normalization. In practice it also has learned scale and bias parameters, but the stabilizing effect is the point. TinyLlama applies it before each sublayer (attention and FFN), a convention called "pre-norm" that's standard in modern transformers.
Putting One Block Together
With these two pieces in place, here is the data flow through a single transformer block:
- Input arrives (2,048-dim vector per token)
- Normalize → Self-attention → Add residual (input + attention output)
- Normalize → FFN → Add residual (previous + FFN output)
- Output passes to the next block
Two sublayers, two residual connections, two normalizations. The original input is always present in the running sum, with attention and the FFN adding updates on top.
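A minimal numpy sketch of that flow, with stand-in sublayers just to show the pre-norm and residual pattern (TinyLlama's norm is actually an RMSNorm variant, but its stabilizing role is the same):

import numpy as np
def layer_norm(x, eps=1e-5):                            # recenter and rescale each token vector
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)              # learned scale/bias omitted
def transformer_block(x, attention, ffn):               # x: [seq_len, d_model]
    x = x + attention(layer_norm(x))                    # normalize, attend, add residual
    x = x + ffn(layer_norm(x))                          # normalize, FFN, add residual
    return x
toy_sublayer = lambda x: 0.1 * x                        # stand-in; the real sublayers are Details 5-8
x = np.random.default_rng(0).standard_normal((3, 2048))
print(transformer_block(x, toy_sublayer, toy_sublayer).shape)   # (3, 2048)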
The Full Stack
This block repeats 22 times in TinyLlama. Block 1's output feeds into block 2, and so on. Each block has different learned weights, so each layer can specialize in different kinds of transformations, but the structure is identical. After the final block, the output vectors enter the LM head (from Detail 2b) for next-token prediction.
We've now opened the entire pipeline. Tokenizer → embedding → positional encoding → 22 transformer blocks (each with attention + FFN, wrapped in residuals and norms) → LM head → sampling → detokenizer. Every component has been explained.
Now, recall from Detail 4 how each autoregressive step reprocesses the entire sequence. With attention's mechanics fresh in mind, you can see exactly what's being recomputed: the Q, K, and V vectors for every previous token, even though they haven't changed. That redundancy is the target of the next section.
Detail 10: The KV Cache
At the end of Detail 9, we had the full transformer pipeline. At the end of Detail 4, we had a compute problem: each autoregressive step re-runs the model on the whole prefix.
With the internals open, we can now point at the exact redundancy.
At generation step t, attention needs:
- the current token's Query at each layer
- all previous tokens' Keys and Values at each layer
When we move from step t to step t+1, previous tokens have not changed. So their Keys and Values have not changed either. Recomputing them is pure duplicate work.
The fix is simple. Compute each token's K and V once, keep them in GPU memory, and reuse them on later steps. On each new decode step, compute K and V only for the newest token and append them.
Attention math stays identical. The current Query still matches against all past Keys, and the weighted sum still pulls from all past Values. The only change is where old K/V come from: cache read instead of recompute.
Why not cache Q too?
Because past Queries are never needed again. At step t, only token t is querying context. At step t+1, token t+1 forms its own Query. Old Queries are one-time intermediates.
This persistent store of past Keys and Values is the KV cache.
What Changes in the Pipeline
With KV cache, the pipeline now has two modes:
- Prefill: Process the full prompt in parallel, and write K/V for all prompt positions into cache.
- Decode: Process one token at a time, read old K/V from cache, append K/V for the new token.
So the model still generates one token per step, but each step now reuses nearly all past attention state.
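A single-head numpy sketch of a cached decode step. The cache is nothing more than two growing arrays of past Keys and Values; real engines keep one pair per layer and per KV head.

import numpy as np
d_model, d_head = 2048, 64
rng = np.random.default_rng(0)
W_q = rng.standard_normal((d_model, d_head)) * 0.02
W_k = rng.standard_normal((d_model, d_head)) * 0.02
W_v = rng.standard_normal((d_model, d_head)) * 0.02
K_cache, V_cache = [], []                               # one entry per past position
def decode_step(x_new):                                 # x_new: [d_model], the newest token only
    q = x_new @ W_q
    K_cache.append(x_new @ W_k)                         # compute K/V once for this token...
    V_cache.append(x_new @ W_v)                         # ...and keep them for every later step
    K, V = np.stack(K_cache), np.stack(V_cache)         # past K/V come from the cache, not recompute
    scores = (q @ K.T) / np.sqrt(d_head)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V                                  # attention output for the new token
for _ in range(5):                                      # five decode steps; cache grows by one each
    out = decode_step(rng.standard_normal(d_model))
print(len(K_cache), out.shape)                          # 5 (64,)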
How Much Memory Does It Take?
The cache stores K and V for every generated position at every layer. Memory grows linearly with sequence length.
TinyLlama uses GQA with 4 KV heads, each 64 dimensions. So per position, per layer:
Per position, per layer:
K: 4 heads × 64 dims = 256 values
V: 4 heads × 64 dims = 256 values
Total: 512 values
At FP16 (2 bytes per value):
Per layer: 512 × 2 bytes = 1 KB per position
All 22 layers: 1 KB × 22 = 22 KB per position
For max context (2,048 tokens):
2,048 × 22 KB ≈ 45 MB per sequence
45 MB per sequence, on top of the 2.2 GB for model weights. That's manageable for one sequence. But with batching (multiple sequences running simultaneously), it adds up. Ten concurrent sequences need 450 MB of cache. For reference, if TinyLlama used standard Multi-Head Attention with 32 KV heads instead of GQA with 4, the cache would be 8× larger: ~360 MB per sequence. The GQA choice from Detail 7 is already paying off here.
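The same arithmetic as a small helper, using the TinyLlama numbers from this section as defaults:

def kv_cache_bytes(seq_len, n_layers=22, n_kv_heads=4, head_dim=64, bytes_per_value=2):
    per_position_per_layer = 2 * n_kv_heads * head_dim * bytes_per_value   # K and V
    return seq_len * n_layers * per_position_per_layer
print(kv_cache_bytes(2048) / 1e6)                   # ~46 MB: the ~45 MB figure above, give or take rounding
print(kv_cache_bytes(2048, n_kv_heads=32) / 1e6)    # ~369 MB with full MHA, 8× larger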
We now have the tool that eliminates redundant computation. But it raises a question: prefill processes everything at once, decode processes one token at a time. They behave differently on GPU hardware. That's the next detail.
Detail 11: Prefill vs Decode
In Detail 4, we saw the autoregressive loop: the model processes the full sequence on every step, with only the last token being new. In Detail 10, we fixed the waste with the KV cache: compute each token's K and V once, store them, never recompute.
This fix naturally splits inference into two distinct phases. To see why, go back to our running example.
Two Different Jobs
The user types "Write a story" and we want the model to generate "Once upon a time..."
The prompt tokens [8144, 264, 3446] are all known before the model runs. We have the complete input sitting there, waiting. We can process all three tokens at once: run them through all 22 transformer blocks in parallel, compute their K and V vectors, and fill the entire KV cache in a single pass. At the end of this pass, the model produces logits for the last position and we sample the first output token: "Once."
Now generation begins. We need the next token. The model processes "Once" through all 22 layers (reading the cached K/V from the prompt tokens for attention), appends its K/V to the cache, and samples "upon." Then it processes "upon," appends, samples "a." Then "time." Each output token depends on the one before it. We can't produce "upon" until "Once" exists. We can't produce "time" until "a" exists.
Notice the asymmetry. The prompt was fully known, so we processed it all at once. The output is built token by token, because each token depends on the one before it. These are two fundamentally different jobs:
- Prefill: Process all prompt tokens in one parallel pass. Fill the KV cache. Emit the first output token.
- Decode: Generate output tokens one at a time. Each step reads the cache, processes one new token, appends its K/V, and samples.
Same model, same weights, same attention formula. But the shape of the work is completely different.
What This Looks Like on Hardware
Scale up from 3 prompt tokens to something realistic:
- Prompt: 1,000 tokens
- Desired output: 200 tokens
During prefill, the GPU processes all 1,000 prompt tokens together. Each layer performs matrix multiplications on a [1,000 × 2,048] input matrix. Attention builds a 1,000×1,000 score matrix per head, then multiplies through Values. That's billions of multiply-add operations per layer, across 22 layers.
The model weights (2.2 GB) get read from memory once, but every byte gets reused across all 1,000 token positions. The GPU's compute units stay busy because there's enough parallel work to keep them fed. The limiting factor is how fast the GPU can do math: there's so much computation per byte of data moved that the GPU spends most of its time computing, not waiting for data. This pattern is called compute-bound.
During decode, each of the 200 steps processes exactly one new token. The GPU reads the same 2.2 GB of weights from memory, but now uses them for a [1 × 2,048] times [2,048 × 2,048] multiplication. About 4 million multiply-add operations. The math finishes in microseconds. Then the GPU sits idle, waiting for the next chunk of weights to arrive from VRAM.
During prefill, 2.2 GB of weights served 1,000 tokens. During decode, those same 2.2 GB serve one. Same data movement. 1,000× less useful work. The limiting factor flips: the GPU now spends most of its time waiting for data, not computing. The compute units are mostly idle, starving for bytes. This pattern is called memory-bandwidth-bound.
The User-Visible Metrics
These two phases map directly to what you experience as a user.
When you send a prompt and wait for the response to start appearing, that pause is prefill running. The time it takes is called Time-to-First-Token (TTFT). It depends on prompt length and GPU compute speed.
Once tokens start streaming, the pace at which they appear is determined by decode. The time between consecutive tokens is called Inter-Token Latency (ITL). It depends on model size and GPU memory bandwidth.
| Phase | Tokens per pass | Bottleneck | User-visible metric |
|---|---|---|---|
| Prefill | 1,000 (parallel) | Compute | TTFT |
| Decode | 1 (sequential) | Memory bandwidth | ITL |
Decode Dominates
The total time to generate a response:
Total latency = TTFT + (ITL × output_tokens)
For our example: 1 prefill pass + 200 decode steps. Even if prefill takes 10× longer than a single decode step, decode still dominates the total: 200 steps add up.
Most wall-clock time is spent in decode, not prefill. Optimizing inference means optimizing decode.
The next detail puts concrete bandwidth numbers to this picture.
Detail 12: Memory Bandwidth is the Bottleneck
In Detail 11, we said decode is memory-bound. Let's put numbers to that claim.
One Decode Step, Count the Bytes
Take one decode step with batch size 1 (generate exactly one next token).
At that step, TinyLlama has to:
- Read model weights: ~2.2 GB (FP16)
- Read KV cache for the current context
- Run the math for this one token
- Write back logits and new KV entries
Look at the proportions. Step 1 reads 2.2 GB. Step 3, the actual math for a single token vector, completes in microseconds. Loading the weights takes milliseconds. The GPU spends almost all its time on data movement, not computation.
And this repeats every decode step. The GPU's on-chip memory (SRAM) is far too small to hold the full model, so it streams the same 2.2 GB from VRAM each time. Token 1: read 2.2 GB. Token 2: read 2.2 GB again. Token 200: read 2.2 GB again. Same weights, same cost, every step.
Bandwidth Sets a Hard Latency Floor
If a GPU can move B GB/s from VRAM, then reading 2.2 GB cannot be faster than 2.2 / B seconds.
Using peak bandwidth numbers:
| GPU | Peak memory bandwidth | Minimum time to read 2.2 GB |
|---|---|---|
| RTX 4090 | 1,008 GB/s | 2.18 ms |
| A100 40GB | 1,555 GB/s | 1.41 ms |
| H100 | 3,350 GB/s | 0.66 ms |
These are lower bounds from peak specs. Real ITL is usually higher due to kernel overheads, imperfect memory access patterns, scheduling gaps, and KV-cache traffic.
So even before talking about compute, decode already has a millisecond-scale floor per token. That floor comes from memory movement.
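That floor is a one-line division; the bandwidth numbers below are the peak specs from the table:

def decode_floor_ms(weight_gb, bandwidth_gb_per_s):
    return weight_gb / bandwidth_gb_per_s * 1000        # time just to stream the weights once
for gpu, bw in [("RTX 4090", 1008), ("A100 40GB", 1555), ("H100", 3350)]:
    print(gpu, round(decode_floor_ms(2.2, bw), 2), "ms")   # 2.18 / 1.41 / 0.66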
How Much Computation Per Byte?
We've shown that decode reads 2.2 GB per token and that the math finishes almost instantly. A useful question follows: for each byte the GPU loads from memory, how many operations does it actually perform?
For decode (batch size 1), rough order of magnitude:
- FLOPs per token: ~2.2 billion
- Bytes moved: ~2.2 GB weights + KV traffic
- Ratio: ~1 FLOP per byte
One floating-point operation per byte loaded. The GPU does one multiply-add, then waits for the next byte to arrive.
For prefill (1,000-token prompt), the same weights serve 1,000 token positions:
- FLOPs: ~2.2 trillion (1,000× more work)
- Bytes moved: still ~2.2 GB (same weights, read once)
- Ratio: ~500-1,000 FLOPs per byte
Same data movement. Orders of magnitude more useful computation.
This ratio of computation to data movement has a name: arithmetic intensity.
Modern GPUs need roughly 100-300 FLOPs/byte to fully utilize their compute units. Decode at ~1 FLOP/byte sits two orders of magnitude below that threshold. The compute cores are mostly idle, starving for data. Prefill at ~500-1,000 FLOPs/byte sits comfortably above it, which is why it's compute-bound.
Why Batching Matters
Batching improves this by reusing each weight read across multiple sequences.
With batch size 1:
- Read ~2.2 GB weights
- Produce 1 token
- Arithmetic intensity: ~1 FLOP/byte
With batch size 32:
- Read ~2.2 GB weights (same weight bytes)
- Produce 32 tokens (one per sequence)
- Arithmetic intensity: ~32 FLOPs/byte
The model weights do not change with batch size. What changes is how much useful work you squeeze out of each byte loaded from memory. That is why production inference systems batch aggressively. We'll see how to do that without hurting latency in Part 2.
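A rough model of that effect, under the simplifying assumption that weight reads dominate the bytes moved (KV-cache traffic ignored):

def arithmetic_intensity(batch_size, params=1.1e9, bytes_per_weight=2):
    flops = 2 * params * batch_size              # ~2 FLOPs per parameter per generated token
    bytes_moved = params * bytes_per_weight      # weights are read once per decode step
    return flops / bytes_moved
for b in (1, 8, 32, 128):
    print(b, arithmetic_intensity(b))            # 1.0, 8.0, 32.0, 128.0 FLOPs/byte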
Part 1 Summary: The Complete Inference Pipeline
We started with a black box. Twelve details later, every component between input text and output text has a name and a mechanism. You can now trace a prompt end to end:
Tracing "Write a story" End to End
- Tokenize: "Write a story" → [8144, 264, 3446]
- Embed and add position: each ID becomes a 2,048-dimensional vector, then positional encoding injects order
- Run the transformer stack: 22 blocks: attention (mix across positions) then FFN (per-position compute)
- Project and sample: LM head → 32,000 logits → probabilities → sample the next token ("Once", for example)
The first token comes from prefill: process the full prompt in parallel, populate the KV cache, emit token 1.
Everything after that is decode. Each step adds one token, reuses the cached K/V from previous tokens, appends the new K/V, and runs the same stack again to pick the next token.
The key asymmetry: at batch size 1, decode rereads essentially all model weights from VRAM to produce one token. The arithmetic finishes quickly. The bytes set the pace.
The Numbers
For TinyLlama serving one sequence:
Model weights (FP16): 2.2 GB
KV cache (2,048 context): ~45 MB (GQA keeps this small)
Activations: ~50-100 MB
Total: ~2.4 GB
The two user-visible metrics map directly to the phases. TTFT (Time to First Token) is how long prefill takes: compute-bound, scales with prompt length. ITL (Inter-Token Latency) is how long each decode step takes: bandwidth-bound, scales with model size and GPU memory bandwidth.
Total latency = TTFT + (ITL × output tokens). For most responses, decode dominates.
If you keep one systems-level idea from Part 1, make it this: decode is a bandwidth problem. Part 2 is mostly different ways of moving fewer bytes, wasting fewer bytes, or sharing those bytes across more useful work.
The Three Bottlenecks
You now understand how inference works. But you may have noticed some problems:
- Large memory footprint: Model weights (2.2 GB for TinyLlama) plus KV cache for every sequence. Memory fills up fast.
- Quadratic attention complexity: Attention scores scale with sequence length squared. Longer contexts = much more computation.
- Sequential generation: Tokens come out one at a time. Each requires a full forward pass through all 1.1B parameters.
Every production optimization addresses one or more of these bottlenecks. Let's see how.
Part 2: Production Optimizations
The following sections cover techniques used by vLLM, TGI, and other inference servers. These are optional if your goal is just conceptual understanding. Part 1 covers the fundamentals.
Detail 13: Batching Requests
In Detail 12, we saw that each decode step reads all 2.2 GB of model weights from VRAM to produce one token. Most of the time is spent on data transfer, not computation. Batching fixes this: process multiple sequences per step, read the weights once, produce B tokens instead of one. The GPU does more useful work per byte loaded.
But that explanation assumed you already had a batch ready to go. In production, requests arrive at unpredictable times and need unpredictable numbers of tokens. How do you actually form batches?
The Simplest Approach
Collect N requests into a group. Run prefill for all of them. Then decode in lockstep: each step reads the weights once and generates one token per sequence. When every sequence in the group has finished, return all results.
This works. With batch size 32, the GPU reads roughly the same 2.2 GB of weights per step but produces 32 tokens instead of 1. Same data movement, 32x the output. Throughput scales nearly linearly with batch size until KV cache memory becomes the limiting factor.
Where It Breaks
The scheme above works perfectly if all requests arrive together and generate the same number of tokens. In practice, neither is true.
Three requests arrive at the same time:
Request A: "Write a poem" → generates 50 tokens
Request B: "Explain gravity" → generates 120 tokens
Request C: "Hello" → generates 20 tokens
You batch them and start decoding. For the first 20 steps, all three slots do useful work. Then Request C finishes. Its result goes back to the user, but its slot sits empty. The batch keeps running until Request B finishes at step 120, so that slot is padding for 100 steps: processed on every forward pass, producing nothing useful.
Request A finishes at step 50. Now two of three slots are padding for the remaining 70 steps.
A new request arriving at step 30 can't join. The batch composition was fixed at step 0. It waits in a queue until step 120, when the entire batch completes and a new one can form.
There is also a cost before decoding even begins. You need N requests to form a batch of N. Under low traffic, you wait for them to accumulate (or use a timeout, which produces smaller, less efficient batches). Either way, users experience added latency from the queueing alone.
This is static batching: the batch composition is fixed from start to finish. No one leaves early, no one joins late.
Static batching improves throughput over serving one request at a time. But it has three problems: finished sequences waste GPU cycles as padding, new requests wait for the entire batch to complete, and short requests are held hostage by long ones. For an interactive service where users are watching tokens stream in, all three matter.
Detail 14: Continuous (In-Flight) Batching
Detail 13 ended with two problems. When Request C finishes at step 20, its slot sits empty until step 120. And any new request arriving at step 30 has to wait for the entire batch to finish.
Both problems come from the same rigidity: the batch composition is fixed from start to finish. Once you form the batch, no one leaves and no one joins.
But there's no computational reason for this. Each decode step reads the model weights once and produces one token per occupied slot. The slots are independent. Request A's token generation doesn't depend on Request C being present. So when C finishes at step 20, we could just remove it and keep decoding with A and B.
And if we can remove finished sequences, we can also add new ones. A free slot is a free slot. If Request D is waiting in the queue and slot 3 just opened, put D there.
Walk through the same three requests with this approach:
Step 0: [A, B, C] all three decoding
Step 20: [A, B, _] C finishes, returned to user. Slot freed.
Step 21: [A, B, D] new request D takes the open slot
Step 50: [_, B, D] A finishes, returned. Slot freed.
Step 51: [E, B, D] request E admitted
...
C leaves at step 20 and its slot does useful work again by step 21. No padding. Request D didn't have to wait for the entire batch to finish; it got in as soon as a slot opened. The GPU stays close to full as long as there is queued work.
This is continuous batching (also called in-flight batching, because the batch composition changes while generation is in flight). Most production inference servers (vLLM, TGI, and others) implement some version of this.
What does "admit a new request" actually involve? The server runs prefill for that request's prompt, allocates KV cache for it, and inserts it into the decode loop. From that point on, it advances one token per step alongside everything else.
The scheduler running this is a tight loop:
- Check for finished sequences. Return their outputs, free their KV cache.
- Check for waiting requests. If KV cache memory allows, run prefill and admit them.
- Run one decode step for all active slots.
- Repeat.
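A toy simulation of that loop, with each sequence reduced to a (context length, tokens left) pair and the KV check reduced to a token budget; a real scheduler tracks actual cache blocks and handles preemption:

from collections import deque
def serve(requests, kv_budget_tokens):
    waiting = deque(dict(ctx=p, left=o) for p, o in requests)       # (prompt_len, output_len)
    active, steps = [], 0
    while active or waiting:
        active = [s for s in active if s["left"] > 0]               # 1. retire finished, free their KV
        while waiting and sum(s["ctx"] for s in active) + waiting[0]["ctx"] <= kv_budget_tokens:
            active.append(waiting.popleft())                        # 2. admit (prefill happens here)
        if active:
            for s in active:                                        # 3. one decode step for all slots
                s["ctx"] += 1
                s["left"] -= 1
            steps += 1
    return steps
print(serve([(500, 50), (300, 120), (5, 20)], kv_budget_tokens=1500))   # 120: bounded by the longest request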
Detail 15: The KV Cache Memory Problem
Detail 14's scheduler has a step that says "if KV cache memory allows, admit a new request." That check sounds simple. It is not.
Continuous batching keeps the GPU busy by running many sequences at once. But every active sequence needs its own KV cache. How many can we actually fit?
For TinyLlama at max context (2,048 tokens):
- KV cache per sequence: ~45 MB (with GQA)
- With 10 concurrent sequences: 450 MB
- Plus model weights: 2.2 GB
- Total: ~2.7 GB, before activations and allocator overhead
TinyLlama is relatively small, and its GQA architecture keeps KV cache modest. Scale up the model or the context length and KV cache quickly becomes multiple gigabytes per sequence. The question stops being "do we have a fast enough GPU" and becomes "how do we allocate memory for all these sequences at once?"
That turns out to be harder than it sounds. You need to decide how to carve up KV memory for each sequence, and the two obvious approaches are both broken.
Approach 1: Reserve for the worst case
The simplest thing to do: when a new sequence arrives, pre-allocate KV cache for the maximum context length. The sequence gets one big contiguous block. No reallocation, no copying as it grows. Clean and predictable.
But think about what this costs. Most requests don't use anywhere near max context. A 50-token request reserves 45 MB for a potential 2,048 tokens:
Allocated: 45 MB
Actually used: (50/2048) × 45 MB = 1.1 MB
Wasted: 43.9 MB (97.5%)
97.5% of that memory does nothing. Multiply across 10 concurrent sequences and you are wasting over 400 MB on space that no token will ever fill.
This is called internal fragmentation: the waste that lives inside each allocation, between what was reserved and what was actually used.
Approach 2: Grow as you go
The natural fix: don't pre-allocate. Start small and grow each sequence's KV allocation as it generates more tokens. Now you only use what you need. Problem solved?
Not quite. Sequences start at different times, grow at different rates, and finish at different times. As short sequences finish and free their memory, they leave gaps. Over time, free space scatters into holes between active allocations:
Memory layout after running for a while:
[Seq 1 KV][ free ][Seq 3 KV][Seq 4 KV][ free ][Seq 6 KV]...
Total free: 400 MB (scattered)
Largest contiguous: 150 MB
New sequence needs: 200 MB contiguous
Result: Can't add new sequence despite having "enough" memory.
This is external fragmentation: plenty of free memory in total, but no single chunk large enough for the next allocation. Free bytes are not enough. Layout matters.
This is why the scheduler's step 2 is more subtle than checking a single number. In real workloads, fragmentation wastes 60 to 80% of the KV cache budget.
Operating systems solved this for regular RAM by stopping the requirement that allocations be physically contiguous. The next detail applies the same idea to KV cache.
Detail 16: PagedAttention (vLLM's Innovation)
Both fragmentation problems from Detail 15 trace back to one assumption: each sequence's KV cache must be a single contiguous block in VRAM. Pre-allocate a big contiguous block and you waste the interior (internal fragmentation). Grow contiguous blocks dynamically and the gaps between them become unusable (external fragmentation). The root cause is not the size of the allocation or the timing of it. The root cause is the contiguity requirement itself.
Operating systems hit this exact problem decades ago with regular RAM. Their solution was to give each program the illusion of contiguous memory without requiring it physically. A program sees a flat range of virtual addresses, numbered 0 through however much memory it needs. But the OS maps fixed-size units called "pages" (typically 4 KB) to physical "frames" scattered anywhere in RAM. The mapping lives in a per-process page table. When a process needs more memory, the OS grabs any free frame, not necessarily one adjacent to the last allocation. Because every frame is the same size and interchangeable, external fragmentation vanishes. Internal fragmentation drops to at most one partially filled page per allocation.
The mapping to KV cache is direct. A sequence is the program. Its logical KV positions (token 0, token 1, ...) are the virtual addresses. Physical frames become fixed-size KV blocks in VRAM, say 16 tokens each. A per-sequence block table plays the role of the page table, translating logical block indices to physical block locations. Can this indirection work for attention without changing the math?
How It Works
Start with the same 50-token request under the old scheme. The system reserves one contiguous block of 2,048 KV slots. 50 slots fill with actual key-value data. The other 1,998 sit empty but locked, unavailable to any other sequence, for the entire lifetime of the request. That is the baseline: 97.5% waste.
Now run the same request through paging with a block size of 16 tokens. Tokens 0–15 fill block 0. Tokens 16–31 fill block 1. Tokens 32–47 fill block 2. Tokens 48–49 go into block 3, leaving 14 slots empty. Four blocks total. The system grabs them one at a time from a free pool, wherever they happen to be in VRAM. Block 0 lands at physical address 7. Block 1 at address 2. Block 2 at address 14. Block 3 at address 5. A per-sequence block table records the mapping:
Block table for Sequence A:
Logical 0 → Physical 7 (tokens 0-15)
Logical 1 → Physical 2 (tokens 16-31)
Logical 2 → Physical 14 (tokens 32-47)
Logical 3 → Physical 5 (tokens 48-49, 14 slots empty)
From the sequence's perspective, its KV cache looks like one contiguous stretch of 50 positions. Physically, it is four scattered chunks.
Block 3 has 14 empty slots. That is real waste, the same kind of internal fragmentation from Detail 15. But paging caps it at 15 slots per sequence (one block minus one token), compared to 1,998 under static pre-allocation. And external fragmentation is gone entirely: every block is the same size and interchangeable, so the swiss-cheese problem from Detail 15 cannot happen. Any free block anywhere in VRAM can satisfy the next allocation. Profiling of systems before paged KV showed that only 20–38% of reserved KV cache memory was actually used by tokens. With block-level paging, waste drops to under 4%.
Now decode begins. New tokens fill the empty slots in block 3 (physical address 5). After 14 more generated tokens, that block is full. The system grabs one more free block from the pool, say physical address 11, and adds a new entry to the block table: logical 4 → physical 11. No existing blocks move. No data is copied. Growth costs one block allocation per 16 generated tokens.
When the sequence finishes, physical blocks 7, 2, 14, 5, and 11 all return to the free pool immediately. Any future sequence can reuse them. There is no compaction step and no need to defragment, because the system never required contiguity in the first place.
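A minimal allocator along these lines, assuming a fixed pool of interchangeable 16-token blocks; the real vLLM allocator adds block sharing, reference counting, and eviction on top:

BLOCK_SIZE = 16
class BlockAllocator:
    def __init__(self, num_physical_blocks):
        self.free_pool = list(range(num_physical_blocks))   # any block can serve any sequence
        self.block_tables = {}                              # seq_id -> list of physical block ids
    def append_token(self, seq_id, position):
        table = self.block_tables.setdefault(seq_id, [])
        if position % BLOCK_SIZE == 0:                      # current block full (or first token)
            table.append(self.free_pool.pop())              # grab any free block, wherever it is
        return table[position // BLOCK_SIZE]                # physical block holding this position
    def free_sequence(self, seq_id):
        self.free_pool.extend(self.block_tables.pop(seq_id))   # reusable immediately, no compaction
alloc = BlockAllocator(num_physical_blocks=64)
for pos in range(50):                                       # the 50-token request from above
    alloc.append_token("seq_a", pos)
print(len(alloc.block_tables["seq_a"]))                     # 4 blocks, as in the walkthrough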
The attention math itself is unchanged. To generate token t, we still need keys and values for positions 0 through t-1. The only difference is one level of indirection in the addressing. Say the attention kernel needs the Key for token 35. Which physical block holds it?
block_size = 16
logical_block = floor(35 / 16) # = 2
offset = 35 % 16 # = 3 (4th slot in that block)
physical = block_table[2] # = 14 (from our table above)
Token 35's Key lives at the 4th slot of physical block 14. The kernel reads it from there and runs exactly the same dot products and softmax as before. One extra lookup per position is the entire cost.
Memory utilization jumps from 20–38% to over 96%. The scheduler can fit far more concurrent sequences in the same VRAM. Benchmarks show 2–4× higher throughput from the same hardware. Not from faster math, but simply from fitting more requests into memory at once.
Back-of-the-Envelope (TinyLlama)
Recall that TinyLlama's max-context KV cache is ~45 MB per sequence. Suppose we have 50 active sequences with an average context length of 500 tokens.
Without paging:
- Reserved: 50 × 45 MB = 2.25 GB
- Actually used: 50 × (500/2048 × 45 MB) ≈ 550 MB
With paging, each sequence allocates only as many blocks as it needs. The slack is at most 15 tokens in the last block per sequence. At ~22 KB per token (45 MB / 2,048 tokens), worst-case slack is about 330 KB per sequence, or ~17 MB across all 50. Total allocation lands near 567 MB instead of 2.25 GB.
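The same estimate in code, so the assumptions (45 MB per max-context sequence, 16-token blocks) are explicit:

per_seq_max_mb, max_ctx = 45, 2048
per_token_mb = per_seq_max_mb / max_ctx                  # ~22 KB per token position
def reserved_mb(n_seqs):
    return n_seqs * per_seq_max_mb                       # pre-allocate the worst case for everyone
def paged_mb(n_seqs, avg_ctx, block_tokens=16):
    worst_case_tokens = avg_ctx + block_tokens - 1       # at most one partly filled block of slack
    return n_seqs * worst_case_tokens * per_token_mb
print(reserved_mb(50))                                   # 2250 MB reserved (~2.25 GB)
print(round(paged_mb(50, 500)))                          # ~566 MB allocated, near the ~567 MB above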
The freed ~1.7 GB is enough for dozens more concurrent sequences. The continuous batching scheduler from Detail 14 can say "yes" to incoming requests far more often.
This technique is called PagedAttention, and the vLLM project introduced it. The name says exactly what it is: paged memory management applied to the attention KV cache. The attention math is untouched. The innovation is entirely in the memory allocator, the same layer that operating systems redesigned decades ago for the same reasons.
We have made KV cache storage efficient. But there is a separate inefficiency in how attention is computed, particularly during prefill. That is next.
Detail 17: FlashAttention
PagedAttention optimized how we store the KV cache. FlashAttention optimizes how we compute attention.
In Detail 12, we counted bytes for decode and found the bottleneck was bandwidth: the GPU spends milliseconds reading model weights and microseconds doing math. The attention step during prefill has its own bandwidth problem, one that has nothing to do with model weights. To see where it comes from, we need a slightly clearer picture of GPU memory.
Two Memories
In Detail 12, we mentioned that the GPU's on-chip memory (SRAM) is too small to hold the model. That was enough context for the bandwidth story. Here we need to look more closely.
A GPU has two memories that matter:
HBM (High Bandwidth Memory). This is what we have been calling "VRAM." On an A100, it holds 40 or 80 GB and delivers about 2 TB/s of bandwidth. Large, but physically distant from the compute cores, millimeters away on the package.
SRAM (Static RAM). Tiny. About 20 MB total on an A100. But roughly 10× faster than HBM, because it sits micrometers from the compute cores instead of millimeters.
Every GPU computation is a story of shuttling data between these two memories. Load operands from HBM into SRAM, compute, write results back to HBM. The closer you can keep data to the compute cores, the less time the GPU spends waiting.
Not all inference hardware is built this way, but GPUs dominate the landscape, and the rest of this section addresses their specific bottleneck.
Standard Attention: Where the Bytes Go
Now let's trace what happens when one attention head runs during prefill, and count how many bytes move between HBM and SRAM.
From Detail 6, TinyLlama splits attention into 32 heads, each working in 64 dimensions. Take a 2,048-token prompt. After the Q, K, V projections, each token has a 64-dimensional Query, Key, and Value vector. So Q is a matrix with one row per token and 64 columns:
Q: [2,048 × 64] (one 64-dim query per token)
K: [2,048 × 64] (one 64-dim key per token)
V: [2,048 × 64] (one 64-dim value per token)
How big are these in memory? 2,048 × 64 = 131,072 numbers. At FP16 (2 bytes per number): about 256 KB each. All three together: roughly 768 KB. Small.
Now run the attention formula from Detail 5. First step: compute the score matrix S = QKᵀ.
What shape is S? Q is [2,048 × 64]. K transposed is [64 × 2,048]. The matrix multiply gives [2,048 × 64] × [64 × 2,048] = [2,048 × 2,048].
Think about what each entry means. Entry (i, j) is the dot product of token i's Query with token j's Key: one attention score for every pair of tokens. With 2,048 tokens, that is 4.2 million scores.
S: [2,048 × 2,048] = 4,194,304 numbers
At FP16: 4,194,304 × 2 bytes ≈ 8 MB
Q was 256 KB. S is 8 MB. The score matrix is 32× larger than the input that produced it. This is the quadratic scaling of attention: double the sequence length and S quadruples in size.
Second step: softmax converts S into attention weights P. Same shape [2,048 × 2,048], same size: 8 MB.
Third step: multiply P by V to get the output. [2,048 × 2,048] × [2,048 × 64] = [2,048 × 64]. Back to 256 KB. The computation balloons in the middle and shrinks back down.
Now here is the problem. On a GPU, each of these three steps runs as a separate operation (called a "kernel"). Separate kernels cannot pass results to each other through SRAM directly. They communicate through HBM: each kernel reads its inputs from HBM, computes, and writes its output back to HBM.
So the data flow for one attention head looks like this:
- Scores: Load Q (256 KB) and K (256 KB) from HBM into SRAM. Compute S = QKᵀ. Write S (8 MB) back to HBM.
- Softmax: Read S (8 MB) from HBM into SRAM. Compute P. Write P (8 MB) back to HBM.
- Output: Read P (8 MB) from HBM into SRAM. Load V (256 KB). Compute PV. Write output (256 KB) to HBM.
Every intermediate takes a round trip through HBM. S gets written once and read once: 16 MB of traffic. P gets written once and read once: another 16 MB. The actual inputs (Q, K, V) and output add about 1 MB. Total per head: roughly 33 MB of HBM traffic, of which about 32 MB is just S and P bouncing back and forth.
Scale up. TinyLlama has 32 heads per layer: 32 × 33 MB ≈ 1 GB per layer. Across all 22 layers: roughly 22 GB of HBM traffic just from attention intermediates. On an A100 at 2 TB/s bandwidth, that is about 11 ms of pure data transfer.
The GPU can compute these matrix multiplies in microseconds. Almost all the time goes to shuttling S and P through HBM. Attention during prefill is IO-bound, just like decode, but for a different reason: decode is bottlenecked on reading model weights (Detail 12), prefill attention is bottlenecked on intermediate matrices that exist only briefly but must pass through slow memory because that is how separate GPU kernels communicate.
The Fix: Never Build the Full Matrix
The [2,048 × 2,048] score matrix was 8 MB per head. It had to pass through HBM because each step (scores, softmax, output) ran as a separate kernel, and separate kernels can only communicate through HBM. What if we fused all three steps into a single kernel that never writes the scores to HBM at all?
Think about it row by row. Row i of the score matrix holds token i's attention scores against all 2,048 keys. Once we have that row, we softmax it, multiply by V to get token i's output, and we are done with that row forever. We never need it again. So there is no reason to store all 2,048 rows at once. We can compute a few rows, use them immediately, and discard them before computing the next batch.
FlashAttention organizes this into tiles:
- Divide Q into blocks of rows (say, 128 rows at a time). Divide K and V into blocks of the same size.
- Load one Q block into SRAM: [128 × 64] = 16 KB. Fits easily.
- Stream K/V blocks through one at a time. For each K block, compute a small score tile: [128 × 128] = 32 KB. Apply softmax (with a correction we will discuss in a moment), multiply by the corresponding V block, and accumulate into a running output. All of this stays in SRAM.
- After all K/V blocks have streamed through, the output for these 128 query rows is complete. Write it to HBM.
- Move to the next Q block and repeat.
The full [2,048 × 2,048] matrix never exists. Not in HBM, not in SRAM, not anywhere. At any moment, only one [128 × 128] tile of scores sits in SRAM: 32 KB, roughly 250× smaller than the 8 MB score matrix it replaces.
The Softmax Obstacle
There is a catch. Softmax normalizes each row by the sum of exponentials across all keys:
When you only have one block of keys in SRAM, you do not have the full sum. You cannot normalize.
The naive fix is two passes: first scan all key blocks to compute the denominator, then scan again to compute the actual output. But that doubles the HBM traffic and defeats the purpose of tiling.
The real fix is to compute the normalized result incrementally, in a single pass, without ever needing the full sum upfront. The idea: maintain three running quantities per query row as key blocks stream through:
- The running maximum m: the largest score seen so far, for numerical stability.
- The running sum of exponentials ℓ = Σ_j exp(s_j - m), adjusted whenever the maximum changes.
- The partial output vector o: the weighted sum of value vectors so far.
When a new block of keys arrives and introduces a new maximum m_new, all previous exponentials shrink by the factor exp(m_old - m_new). Apply that correction to ℓ and o with one multiply each, then incorporate the new block's scores normally. When the maximum does not change, the correction factor is 1 and the update is a simple accumulation.
The result is numerically identical to standard softmax. Not an approximation. The same answer, computed in one pass through the keys. This technique is called online softmax.
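Here is that update as a numpy sketch for a single query row, streaming K and V in blocks and keeping only m, the running sum, and the partial output (the real kernel does this for a whole tile of query rows at once, entirely in SRAM):

import numpy as np
def online_softmax_attention(q, K, V, block=128):
    d = q.shape[0]
    m, l, o = -np.inf, 0.0, np.zeros(d)                   # running max, sum of exponentials, partial output
    for start in range(0, K.shape[0], block):
        s = q @ K[start:start+block].T / np.sqrt(d)       # scores for this key block only
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)                    # rescale everything accumulated so far
        p = np.exp(s - m_new)
        l = l * correction + p.sum()
        o = o * correction + p @ V[start:start+block]
        m = m_new
    return o / l                                          # identical to full softmax(QKᵀ/√d)·V
rng = np.random.default_rng(0)
K, V, q = rng.standard_normal((2048, 64)), rng.standard_normal((2048, 64)), rng.standard_normal(64)
s = q @ K.T / 8.0                                         # reference: materialize all 2,048 scores
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
print(np.allclose(online_softmax_attention(q, K, V), ref))   # True: same answer, no full score matrix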
What This Buys You
Standard attention materializes the full [N × N] score matrix per head. For one head of TinyLlama at N = 2,048, that is 8 MB. Across 32 heads and 22 layers, the intermediates drive roughly 22 GB of HBM traffic per prefill pass. FlashAttention's working memory is just the SRAM tiles: a [128 × 128] score tile (32 KB) plus the running output. The 8 MB intermediate per head shrinks to tens of kilobytes.
The total FLOPs are the same. FlashAttention does not skip any computation. It just avoids writing and reading the big intermediate matrices. The speedup comes entirely from reduced HBM traffic. FlashAttention-2 achieves 50 to 73% of an A100's theoretical compute throughput, compared to about 25% for standard attention.
In practice, you rarely call FlashAttention directly. PyTorch's scaled_dot_product_attention selects a flash kernel automatically when hardware allows. Most production inference stacks use it under the hood.
One scope note. FlashAttention matters most during prefill, where the query length is the full prompt and the [N × N] score intermediate is large. During decode, the query length is 1: the new token's scores are a single row, and there is no large matrix to materialize. The decode bottleneck remains what Detail 12 identified: reading the model weights.
We have optimized how attention is computed and how KV cache memory is laid out. But what if the model weights themselves are simply too large? What if the GPU does not have enough memory to hold them at all?
Detail 18: Quantization
In Detail 12 we established a hard floor: each decode step reads the full weight set from VRAM. For TinyLlama that is 2.2 GB, every single token. The GPU spends most of its time just moving those bytes.
Now scale that up. A 70B-parameter model stored in FP16 is about 140 GB of weights. An A100 has 80 GB of VRAM. The weights alone do not fit on one GPU, let alone leave room for KV cache, activations, or a batch of requests. You either shard across multiple GPUs (expensive, complex) or you ask a different question: can we make the weights smaller?
Fewer Bits Per Weight
A neural network weight stored in FP16 uses 16 bits (2 bytes). The value might be something like 0.0217. But the model does not need all 16 bits of precision for inference. The weights just get multiplied and accumulated in matrix multiplies. So instead of storing the full float, you map it to a small integer (say, in the range -127 to 127) plus a single scale factor shared across a group of weights. The integer takes 8 bits. Half the storage, and the rounding error is small enough that the model's outputs barely change.
Here is a simple symmetric scheme:
scale = max(abs(W)) / 127
q = round(W / scale) # INT8 values in [-127, 127]
W_hat = q * scale # approximate weight for matmuls
Walk through one weight. Suppose max(abs(W)) for a group is 1.0, so scale = 1/127 ≈ 0.00787. The weight 0.0217 maps to round(0.0217 / 0.00787) = round(2.76) = 3. Reconstructed: 3 × 0.00787 = 0.0236. The reconstructed value is 0.0236 instead of 0.0217. A small rounding error, but the weight now takes 8 bits instead of 16.
INT4 is the same idea with a much smaller range (roughly -7 to 7). Fewer levels means coarser rounding, but each weight is only 4 bits, a quarter of the original.
In practice, you do not compute one scale for the entire weight matrix. There is a spectrum of granularity: per-tensor (one scale for the whole matrix), per-channel (one scale per output channel), and per-group (one scale per small block of weights, typically 64 to 128 values). Per-tensor is simplest but a single outlier stretches the range and wastes precision for everything else. Per-channel is better for weight matrices where magnitudes vary across output dimensions. Per-group gives the tightest precision: each small block gets its own scale tuned to its local range. Most INT4 methods (GPTQ, AWQ) use per-group quantization, which is a big reason they maintain reasonable quality at 4-bit compression.
This process of mapping floating-point weights to low-bit integers plus a scale factor is called quantization.
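A sketch of per-group quantization with the symmetric INT8 scheme from above and a group size of 128 (real GPTQ/AWQ pipelines add calibration and error compensation on top):

import numpy as np
def quantize_per_group(W, group_size=128, qmax=127):
    groups = W.reshape(-1, group_size)                         # one scale per group of weights
    scales = np.abs(groups).max(axis=1, keepdims=True) / qmax
    q = np.clip(np.round(groups / scales), -qmax, qmax).astype(np.int8)
    return q, scales
def dequantize(q, scales):
    return q * scales                                          # done on the fly inside real matmul kernels
rng = np.random.default_rng(0)
W = (rng.standard_normal(2048 * 5632) * 0.02).astype(np.float32)   # one FFN-sized weight matrix, flattened
q, scales = quantize_per_group(W)
print(q.nbytes / W.nbytes)                                     # 0.25: INT8 is half of FP16, a quarter of the FP32 used here
print(float(np.abs(dequantize(q, scales).ravel() - W).max()))  # small per-weight rounding error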
What Gets Quantized
When people say "INT4 model," they nearly always mean weight-only quantization. The model's fixed parameters (attention projections, FFN matrices, LM head) get compressed to low-bit integers. The activations, meaning everything computed during a forward pass as data flows through the network, stay in FP16 or BF16.
Why not quantize everything? Weights are the same numbers every forward pass. You can study their distribution, choose good scale factors, and be done. Activations are a different story. They change with every input. Worse, they spike: a handful of dimensions can hit values 100× larger than the rest. Force those into a low-bit integer range and the outliers consume all the dynamic range while everything else rounds to near-zero.
So the standard setup is INT4 or INT8 weights with FP16 activations. During matrix multiplies, the GPU reconstructs each weight on the fly (multiply the stored integer by its scale factor), multiplies against the full-precision activation, and accumulates the result. The dequantization happens inside the compute pipeline, not as a separate storage step.
TinyLlama at Different Precisions
Back to the bandwidth argument from Detail 12. If decode is bandwidth-bound, the time per token cannot be faster than weight_size / bandwidth. Shrink the weights, shrink the floor.
| Precision | Bytes per weight | TinyLlama weight size | Bandwidth floor per token (A100) |
|---|---|---|---|
| FP16 | 2 | 2.2 GB | ~1.41 ms |
| INT8 | 1 | 1.1 GB | ~0.71 ms |
| INT4 | 0.5 | 0.55 GB | ~0.35 ms |
These numbers are approximate. Real quantized formats store scale (and sometimes zero-point) metadata per group, so the file is slightly larger than the naive calculation. And the speedups are bandwidth ceilings, not guarantees. Kernel overhead, dequantization cost, KV-cache traffic, and batch size all affect the actual latency.
But the direction is reliable: if decode is bandwidth-bound, shrinking weights shrinks ITL.
What Quantization Does Not Change
If you only quantize weights, the KV cache stays in FP16 or BF16. For short contexts and small batches, weights dominate memory and quantization helps a lot. But as context length grows or batch size increases, KV cache becomes a larger share of total memory. Quantization shrinks the fixed cost (weights). It does not touch the per-sequence cost (KV cache).
Quantization also does not reduce the number of forward passes. You still read the (now smaller) weights once per token. The cost per read drops, but the number of reads stays the same. Each decode step still produces exactly one token. Keep that in mind for later.
Fitting a 70B Model on One GPU
Scale the TinyLlama math up. A 70B model at FP16 needs about 140 GB just for weights. No single GPU holds that. You would need to shard across at least two 80 GB GPUs, with all the communication overhead that entails.
INT4 quantization brings the weight footprint to roughly 35 GB. Add metadata overhead, call it 38 to 40 GB. That fits on one 80 GB GPU with room left for KV cache and activations. A multi-GPU problem reduced to a single-GPU problem.
The Cost: Quality
Go back to our example. The weight 0.0217 became 0.0236 after INT8 rounding. That's an error of 0.0019. In a single matrix multiply, thousands of such slightly-off weights get multiplied against an input vector. Some round up, some round down. The errors partially cancel. The net effect on any single output value is tiny.
But not all weights are equally forgiving. A weight at 0.0118 sits right between two INT8 grid points. It could round to 0.00787 or 0.01575. Which way it falls can nudge a downstream logit just enough to flip which token wins the softmax. One flipped token early in generation can alter the entire trajectory.
How often this matters depends on how coarse the grid is. INT8 gives you 256 levels. That's fine enough that in practice, you'd struggle to tell the quantized model from the original on most tasks. INT8 is the safe default. INT4 gives you 16 levels. The grid is much coarser, and the rounding errors are larger. For conversational use and general question-answering, INT4 usually works well. For tasks that demand precision (multi-step math, exact code generation, detailed factual recall), you start to see the model slip on cases the full-precision version handles cleanly. It doesn't become stupid. It just gets slightly less reliable on the hard cases.
Below INT4, things get fragile fast. 3-bit and 2-bit quantization are active research areas, but at those levels the grid is so coarse that even careful calibration can only partially compensate.
Calibration itself is not retraining. You run a small representative set of inputs through the model and measure how quantization errors propagate through the layers. The goal is to pick scales (and sometimes identify outlier channels) that minimize the effect on output quality. It takes minutes, not days.
Quantization gives you cheaper weight reads. Each decode step moves fewer bytes, so the bandwidth floor drops and ITL improves. You can also fit larger models on fewer GPUs.
But quantization did not change how many reads you do. The decode loop still produces one token per forward pass. You made each read cheaper, but you are doing just as many of them. Is there a way to get multiple tokens out of a single forward pass?
Detail 19: Speculative Decoding
Quantization (Detail 18) shrinks the weights so each decode step moves fewer bytes. Batching (Details 13 and 14) amortizes weight reads across requests: read the weights once, produce a token for every sequence in the batch. Both attack the bandwidth bottleneck from Detail 12.
But for a single user's stream, each token still costs a full weight read. 200 tokens means 200 reads of 2.2 GB. The math finishes instantly each time; the data transfer doesn't. Is there a way to get more than one token out of that transfer?
The Dependency Chain
The naive idea: run the model once and produce 4 tokens instead of 1.
But think about what that requires. Suppose the prefix is "Write a story. Once upon a" and we want the next 4 tokens. Token 1 might be "time". But token 2 depends on token 1. If token 1 is "time", token 2 is probably ",". If token 1 were "hill", token 2 might be "there" instead. You can't compute the right token 2 without first knowing token 1. Token 3 depends on token 2. And so on.
Generation is sequential. Each token depends on the one before it. That's the fundamental constraint from Detail 4.
But verification is not. If someone hands you a sequence of candidate tokens, you can check all of them in parallel. That asymmetry is the key.
Verification Is Almost Free
Suppose someone handed you 5 candidate tokens and asked: "check if these are correct." Could you do that cheaply?
Recall from Detail 12: during decode, the GPU spends most of its time loading weights from memory, not computing. The compute units are mostly idle. One normal decode step with TinyLlama:
- Read ~2.2 GB of weights from VRAM
- Run the math for one token: ~2.2 billion FLOPs
- Arithmetic intensity: ~1 FLOP/byte
The GPU does one multiply-add, waits for the next byte, does another multiply-add. At ~1 FLOP/byte, it's using maybe 1% of its compute capacity. The other 99% goes to waste.
Now check 5 candidate tokens in a single forward pass, the same way prefill processes a full prompt in one pass (Detail 11). The weight read is still ~2.2 GB. But now the GPU multiplies each weight against 5 token vectors instead of 1. Arithmetic intensity goes up by ~5×. Still well below the GPU's compute ceiling, so the extra math barely adds to wall-clock time.
Verifying 5 tokens costs roughly the same wall-time as generating 1 token. Same weights loaded, same bandwidth consumed. The extra computation was going to waste anyway.
If you have good guesses for the next few tokens, checking them is almost free. The question is where those guesses come from.
Draft, Then Verify
Look at the sentence "Once upon a time, there was a". After "upon a", the word "time" is near-certain. The comma after "time" is obvious. "there" after ", " is a strong guess. These aren't hard predictions. Most of the difficulty in language generation comes from a few pivotal tokens; the rest is syntactic scaffolding, common phrasing, and boilerplate.
A much smaller, simpler model could get these easy tokens right. So use one. Take a 125M-parameter model from the same family as the target. Its weights are ~250 MB, roughly 9× smaller than TinyLlama. Each of its forward passes is proportionally cheaper.
The idea: let the small model run ahead and generate K candidate tokens autoregressively. Then hand those candidates to the real model for verification.
- The small model generates K candidate tokens autoregressively (fast, because it's small).
- The target model runs one forward pass over the prefix plus all K candidates. This costs roughly the same as generating 1 token normally, for the reason we just established.
- Walk the K candidates left to right. At each position, the target model has computed its own logits from the full context. For greedy decoding: if the candidate matches the target's argmax, accept it. If not, use the target's choice at that position and discard the remaining candidates.
A Concrete Trace
Prefix: "Write a story. Once upon a"
Draft. The 125M model generates 4 tokens: ["time", ",", "there", "was"]. Four forward passes, each reading ~250 MB.
Verify. TinyLlama takes the prefix plus all 4 draft tokens and runs one forward pass. It reads the same 2.2 GB it would for a single token, but produces logits at every draft position:
Position 1: draft = "time" → target agrees ✓ accept
Position 2: draft = "," → target agrees ✓ accept
Position 3: draft = "there" → target disagrees ✗ → resample "in" from target
Position 4: draft = "was" → discarded
Result: 3 tokens from one target forward pass (2 accepted + 1 resampled at the rejection point). Without this approach, those 3 tokens would have cost 3 separate weight reads of 2.2 GB each.
At worst, all 4 drafts are wrong. We still get 1 token (resampled at position 1). We never go slower than normal decoding. At best, all 4 drafts are right and we get 5 tokens: the 4 accepted candidates, plus 1 bonus from the target's logits at the last accepted position.
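A sketch of the greedy draft-then-verify loop. The two model calls are toy stand-ins (random logits), so this shows only the accept/reject logic; the sampling variant adds the rejection-sampling correction discussed below.

import numpy as np
rng = np.random.default_rng(0)
VOCAB = 32000
def toy_draft_generate(prefix, k):                     # stand-in for the 125M draft model
    return [int(t) for t in rng.integers(0, VOCAB, size=k)]
def toy_target_forward(tokens, k):                     # stand-in for TinyLlama: one pass, k+1 predictions
    return rng.standard_normal((k + 1, VOCAB))
def speculative_step(prefix, k=4):
    candidates = toy_draft_generate(prefix, k)
    logits = toy_target_forward(prefix + candidates, k)    # one weight read covers all candidates
    out = []
    for i, cand in enumerate(candidates):
        target_choice = int(np.argmax(logits[i]))
        if cand == target_choice:
            out.append(cand)                           # draft matched the target's greedy choice
        else:
            out.append(target_choice)                  # correct it, discard the remaining drafts
            return out
    out.append(int(np.argmax(logits[k])))              # all accepted: bonus token from the same pass
    return out
print(speculative_step([8144, 264, 3446]))             # between 1 and k+1 tokens per target pass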
This is speculative decoding. The name borrows from CPU design, where processors speculatively execute instructions before knowing if a branch condition is true, discarding the work if the prediction was wrong.
Does This Change the Output?
A smaller model is proposing tokens and we're sometimes accepting them. Does the output degrade to the draft model's quality?
Think about what the target does during verification. At each position, it computes its own logits from the full context. If the draft happened to guess the target's top token, we accept it. If not, we use the target's choice. Either way, the token that enters the sequence is the one the target would have picked. The draft affected which tokens got checked, not which ones got emitted.
For greedy decoding, the output is exactly what the target model would have produced on its own. The draft model only affected speed, not the result.
For sampling (temperature, top-p), the guarantee still holds through a rejection sampling scheme: tokens where the draft was overconfident get accepted less often, correcting for the bias. The math ensures the final distribution is identical to sampling from the target directly.
At worst, all draft tokens are rejected and we resample from the target at position 1. Cost: one token from one forward pass, the same as normal decoding. Speculative decoding never makes things slower.
System Implications
Running two models means memory for both weight sets and two KV caches. The draft model is kept small (and often quantized) to minimize overhead:
Target weights (TinyLlama): 2.2 GB
Draft weights (125M): ~250 MB
Total: ~2.45 GB (plus KV caches for both)
When a draft token is rejected, the draft model's KV cache rolls back to the last accepted position. Why roll back? If token 3 was rejected and resampled, the draft's predictions for tokens 4 onward were conditioned on the wrong token 3. Those cache entries are invalid. The draft clears back to the last accepted position and re-drafts from there. The target cache only commits accepted entries.
When It Helps (and When It Doesn't)
The speedup depends on one number: the acceptance rate, how often the draft model's guess matches what the target would have produced. With K = 5 draft tokens:
| Acceptance rate | Expected tokens per target forward pass |
|---|---|
| 80% | ~4 |
| 50% | ~2 |
| 20% | ~1.2 (barely helps) |
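The table's numbers follow from a simple model: if each draft token independently matches the target with probability a, one target pass yields an expected (1 - a^(K+1)) / (1 - a) tokens.

def expected_tokens_per_pass(a, k=5):
    return (1 - a ** (k + 1)) / (1 - a)       # accepted run of drafts, plus one token from the target
for a in (0.8, 0.5, 0.2):
    print(a, round(expected_tokens_per_pass(a), 1))   # 0.8 -> 3.7 (≈4), 0.5 -> 2.0, 0.2 -> 1.2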
Works well when:
- The draft model is from the same family as the target (high agreement)
- Output is locally predictable: boilerplate, common phrasing, structured text, code
- Batch sizes are small (decode is bandwidth-bound, so the "free compute" that makes verification cheap is actually available). With small batches, the GPU's compute units are underutilized (Detail 12), so verification truly costs nothing extra. With large batches, the GPU is already busy, and the extra computation from verification starts competing for resources.
Works poorly when:
- The draft model is weak relative to the target (low acceptance rate means most draft work is wasted)
- Large batches already saturate the GPU's compute (verification is no longer "free" because the GPU is busy serving many sequences per step)
- Memory is tight (two models plus two KV caches may not fit)
Speedups of 2-3× on decode-heavy, low-batch workloads are common.
Variants
What if you don't want a separate draft model?
Self-speculative decoding. The target model drafts for itself by skipping its later transformer layers. Run only the first N layers (say 10 of TinyLlama's 22) for drafting. Early layers see mostly local syntax, enough to predict function words and common continuations. Then verify with all 22 layers. No extra model, no extra memory for a second set of weights. Trade-off: early layers capture less context than the full stack, so acceptance rates tend to be lower.
Medusa. Attach lightweight prediction heads to the target model's final hidden layer. Head 1 predicts t+1, Head 2 predicts t+2, and so on. Each head proposes several candidates, forming a tree of possible continuations that gets verified in one forward pass. The limitation: the heads are independent. Head 3 doesn't know what Heads 1 and 2 predicted. It's guessing 3 steps ahead without knowing the intermediate tokens, which limits accuracy for later positions.
EAGLE. Instead of predicting future tokens directly (hard, since many tokens are plausible at each step), predict future hidden states (the model's internal representations). Hidden states are continuous and smooth: similar inputs produce similar outputs. A small autoregressive head predicts the next hidden state given both the current state and the embedding of the token actually sampled. That conditioning on the chosen token resolves the ambiguity that Medusa's independent heads cannot handle. EAGLE variants report 3-5× speedups.
Lookahead decoding. Uses an iterative refinement technique (Jacobi iteration) that produces candidate token sequences as a natural byproduct. No separate model, no extra heads. Verify those candidates against the target. Works best on repetitive or structured text where the same patterns recur.
Detail 20: Prefix Caching
The KV cache (Detail 10) eliminated redundant work within a single request: once a token's Keys and Values are computed, the model never recomputes them for subsequent tokens. But the savings stop at the request boundary. Every new prompt triggers a fresh prefill (Detail 11), even if the server has seen most of those tokens minutes ago.
Consider a chatbot whose system prompt (instructions, safety guidelines, tool descriptions) totals 500 tokens. Every user message adds another 10 to 50 tokens on top. At 100 requests per minute, the server prefills 50,000 system-prompt tokens against roughly 2,000 unique user-message tokens. About 96% of the prefill computation is duplicate work, producing KV entries the server has already computed for the previous request.
Can we compute the system prompt's KV entries once and hand them to every subsequent request? The idea is tempting, but it raises an immediate question: are those entries merely similar across requests, or truly identical? If similar, reuse would be an approximation. If identical, reuse is exact.
Recall the causal mask from Detail 5. Position i can only attend to positions 1 through i. Information flows strictly forward, never backward. That means the Key and Value vectors at position i depend on exactly one thing: the tokens at positions 1 through i. What the user types at position 501 cannot reach back and change the KV entry at position 200. So for two requests that share the same first 500 tokens, the KV entries at every shared position are byte-identical. Not approximately equal. Identical. The causal mask makes this a hard guarantee.
Compute once, store, attach to subsequent requests. This is prefix caching.
Here is how it plays out. Request 1 arrives: a 500-token system prompt followed by "What is the capital of France?" (12 tokens), totaling 512 tokens. The prefix cache is empty, so the server runs full prefill on all 512 tokens. At 22 KB per token position (Detail 10), the system prompt's KV footprint is 500 × 22 KB = 11 MB. The server hashes the 500-token prefix and stores a reference to those KV blocks. In systems using PagedAttention (Detail 16), the cached entries live as blocks that can be shared across requests without copying, the same virtual-memory trick that already manages per-request KV memory.
Request 2 arrives seconds later: the same 500-token system prompt followed by "Explain gravity" (8 tokens). The server hashes the first 500 tokens, finds a match, and attaches the cached KV blocks without recomputing them. Prefill runs only on the 8 new tokens. The workload drops from 508 tokens to 8, a 98% reduction. Since prefill dominates TTFT (Detail 11), time-to-first-token drops roughly in proportion.
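A minimal sketch of the bookkeeping, assuming a hypothetical in-memory block store: hash the prompt block by block, reuse every leading block already present, and prefill only the remainder. The 16-token block size mirrors vLLM's default; real systems (vLLM's prefix caching, LMCache) add reference counting and eviction on top of this lookup.

```python
import hashlib

BLOCK_SIZE = 16  # tokens per KV block (vLLM's default block size)

def block_hashes(token_ids):
    """One hash per full block, each covering the entire prefix up to and
    including that block, so a hit implies all earlier blocks also hit."""
    hashes, h = [], hashlib.sha256()
    full = len(token_ids) - len(token_ids) % BLOCK_SIZE
    for start in range(0, full, BLOCK_SIZE):
        h.update(repr(token_ids[start:start + BLOCK_SIZE]).encode())
        hashes.append(h.copy().hexdigest())
    return hashes

def split_prompt(token_ids, cache):
    """Return (tokens served from cache, tokens that still need prefill)."""
    reused = 0
    for block_hash in block_hashes(token_ids):
        if block_hash in cache:
            reused += BLOCK_SIZE          # attach the existing KV block, no compute
        else:
            break                          # first miss ends the shared prefix
    return reused, token_ids[reused:]
```

After prefill, the server inserts the new blocks' hashes so the next request can reuse them. Because reuse happens one whole block at a time, a prefix that is not a multiple of the block size leaves a small tail to recompute.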
The improvement scales with the ratio of shared prefix to unique suffix. A 500-token prefix with a 10-token suffix means prefill handles 2% of the original tokens. A 100,000-token prefix with a 200-token suffix means 0.2%. For a real production data point: Anthropic reports that a 100K-token cached prompt reduces TTFT from roughly 11.5 seconds to roughly 2.4 seconds. Inter-token latency is unchanged. Decode is still bandwidth-bound on model weights (Detail 12), and prefix caching does not touch the decode phase.
There is an important constraint: cache hits require exact token matches. Prefix caching operates on token sequences, not semantic similarity. A single extra space, a different quote style, or an injected timestamp changes the token IDs and breaks the cache from that point forward. To get reliable hits, the shared prefix must be byte-identical before tokenization.
A few practical realities shape deployment. The cache consumes the same scarce GPU memory discussed in Detail 15, so servers use LRU or TTL eviction to keep only the most frequently reused prefixes resident. Because the cache operates at PagedAttention block granularity (Detail 16), a prefix is reused in whole blocks: a 500-token prefix spans roughly 32 blocks at a 16-token block size. And if cached prefixes include user-provided text, many production systems scope caches per tenant or restrict caching to known-safe system prompts, treating the cache as user data.
With prefix caching handling redundant prefill across requests, we have optimized both phases of inference for individual requests. Detail 21 steps back to look at the system-level trade-off between latency and throughput, where the interests of a single user and the interests of the server diverge.
Detail 21: Latency vs Throughput Trade-offs
Every optimization so far has a beneficiary. Sometimes it is the person typing the prompt. Sometimes it is the company paying for the GPU. Often, making one happier makes the other worse.
What the User Feels
Imagine typing a prompt into a chat interface and hitting send. There is a pause while the server processes your prompt (prefill). Then the first token appears. After that, tokens stream in one by one, each arriving after a short gap.
Two distinct frustrations can arise. If the initial pause drags, the interface feels unresponsive before it even starts. If the gaps between tokens stutter or stretch, the streaming feels choppy, even if the total number of tokens arrives in the same wall-clock time. Both are bad, but they stem from different parts of the system.
The initial pause is called TTFT (time to first token). It equals prefill time plus any time the request spent waiting in a queue. The gap between subsequent tokens is called ITL (inter-token latency). It equals the decode step time plus scheduling jitter.
For TinyLlama on a single A100, serving one user with no contention:
- Prefill for a 100-token prompt: ~5 ms (compute-bound, the GPU tears through it)
- TTFT: ~5 ms (no queue, no waiting)
- ITL: ~1.4 ms (2.2 GB of weights / 1,555 GB/s of bandwidth, per Detail 12)
- A 200-token response: 5 + (200 × 1.4) ≈ 285 ms end-to-end
For a single user with no competition for the GPU, both metrics are excellent. This will not survive contact with a second user.
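The arithmetic behind those numbers is simple enough to write down directly. This sketch uses the post's own figures (2.2 GB of weights, 1,555 GB/s of bandwidth, ~5 ms of prefill per 100 prompt tokens) and ignores real-world overheads such as sampling and kernel launches:

```python
WEIGHT_BYTES = 2.2e9             # TinyLlama FP16 weights
HBM_BW = 1555e9                  # A100 80GB memory bandwidth, bytes/s
PREFILL_MS_PER_100_TOKENS = 5.0  # rough compute-bound prefill cost (Detail 11)

def ttft_ms(prompt_tokens, queue_ms=0.0):
    return queue_ms + PREFILL_MS_PER_100_TOKENS * prompt_tokens / 100

def itl_ms():
    # Decode floor: one full weight read per generated token (Detail 12).
    return WEIGHT_BYTES / HBM_BW * 1e3

def end_to_end_ms(prompt_tokens, output_tokens):
    return ttft_ms(prompt_tokens) + output_tokens * itl_ms()

print(f"TTFT ~{ttft_ms(100):.1f} ms, ITL ~{itl_ms():.2f} ms, "
      f"200-token response ~{end_to_end_ms(100, 200):.0f} ms")
# ~5.0 ms, ~1.41 ms, ~288 ms (the text rounds ITL to 1.4 ms, giving ~285 ms)
```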
What the GPU Sees
The user cares about their stream. The operator cares about the whole GPU. The operator's metric is throughput: total output tokens generated per second, summed across all active requests.
Latency is per-stream. Throughput is per-GPU.
With a single request in flight, our A100 generates roughly 700 tokens/sec for that one stream. That is a fine number for the user, but the GPU is almost entirely idle. Detail 12 showed that a single decode stream uses less than 1% of the A100's compute capacity. The arithmetic units have nothing to do while the memory bus ferries weights. From the operator's perspective, serving one user on an A100 is like reserving an entire highway for a single car.
Batching (Detail 14) fills that idle capacity by processing multiple sequences per weight read. But every new sequence adds its own KV cache reads to the bandwidth bill.
Where the Tension Comes From
Why does per-stream ITL get worse as you add more sequences to the batch?
During decode, each step reads the model weights once: 2.2 GB for TinyLlama, regardless of batch size. Those weights are shared across all sequences. But the attention mechanism also reads each sequence's KV cache, its private record of all prior keys and values. More sequences means more KV data to read per step.
TinyLlama with GQA uses 22 KB of KV cache per token position. A sequence with 300 tokens of context (100-token prompt plus 200 generated so far) carries about 6.6 MB of KV cache. That is tiny next to 2.2 GB of weights. But multiply it by the batch size:
| Batch size | Weight read | KV read | Total read | Step time | Total tok/s | Per-stream tok/s |
|---|---|---|---|---|---|---|
| 1 | 2.2 GB | 7 MB | 2.21 GB | 1.42 ms | 704 | 704 |
| 8 | 2.2 GB | 53 MB | 2.25 GB | 1.45 ms | 5,517 | 690 |
| 32 | 2.2 GB | 211 MB | 2.41 GB | 1.55 ms | 20,645 | 645 |
| 128 | 2.2 GB | 845 MB | 3.05 GB | 1.96 ms | 65,306 | 510 |
At batch size 1, the GPU reads 2.2 GB and produces one token. At batch size 128, it reads 3.05 GB and produces 128 tokens. Throughput jumps 93× while per-stream ITL only increases 38%. The economics are lopsided: 38% more latency per stream buys 93× more total output.
But notice the trend. Weight reads are constant. KV reads scale linearly with batch size. At batch 128, the KV traffic is already 845 MB, nearly 40% of the weight read. Push the batch larger and KV traffic eventually overtakes weight traffic. From that point, each additional sequence directly inflates step time.
This is the same roofline picture from Detail 12, now applied to the batch. Below a critical batch size, each new sequence adds so little extra bandwidth demand that step time barely moves. Throughput scales nearly for free.
Above that critical batch size, the fixed weight read no longer dominates: the growing per-sequence KV reads (and, eventually, the arithmetic itself) set the step time, and every additional sequence pushes it higher in proportion.
For TinyLlama with short contexts on an A100, the critical batch size is quite large because GQA keeps the per-sequence KV footprint small. For a model with full multi-head attention, longer contexts, or both, the crossover comes much sooner.
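To see where the table and the critical-batch-size argument come from, here is a rough step-time model using the same numbers (22 KB of KV per token position, 300-token contexts). It treats the decode step as purely bandwidth-bound, which is the roofline assumption at these batch sizes:

```python
WEIGHT_BYTES = 2.2e9        # TinyLlama FP16
KV_BYTES_PER_TOKEN = 22e3   # per token position with GQA (Detail 10)
HBM_BW = 1555e9             # A100 bytes/s

def decode_step(batch_size, context_len=300):
    """Bandwidth-only model: each step reads all weights once, plus every
    sequence's KV cache. Returns (step ms, total tok/s, per-stream tok/s)."""
    bytes_read = WEIGHT_BYTES + batch_size * context_len * KV_BYTES_PER_TOKEN
    step_s = bytes_read / HBM_BW
    return step_s * 1e3, batch_size / step_s, 1 / step_s

for b in (1, 8, 32, 128, 512):
    ms, total, per_stream = decode_step(b)
    print(f"batch {b:>3}: {ms:.2f} ms/step, {total:>7,.0f} tok/s total, "
          f"{per_stream:.0f} per stream")
# At batch 512, KV traffic (~3.4 GB) has overtaken the 2.2 GB weight read.
```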
Prefill-Decode Interference
The table above assumed a batch of sequences all in the decode phase, generating one token per step. But continuous batching (Detail 14) means new requests arrive constantly, and each one needs prefill before it can join the decode loop.
Here is the problem. Prefill is compute-heavy: it processes the entire prompt in one large matrix multiplication. Decode is bandwidth-heavy: it reads weights and KV cache to generate one token per sequence. When a long prefill lands in the middle of a decode batch, it monopolizes the GPU pipeline for its duration. Every active decode stream has to wait.
Concrete scenario: 32 sequences are decoding smoothly at ~1.55 ms per step. A new request arrives with a 1,000-token prompt. Its prefill takes roughly 50 ms. During those 50 ms, all 32 streams produce zero tokens. Their ITL for that step spikes from 1.55 ms to 50+ ms, a 30× jump. The user sees a visible stutter.
The naive fix is to limit how many tokens get prefilled at once. But throttling prefill hurts TTFT for the incoming request.
A better approach is chunked prefill. Instead of processing the full 1,000-token prompt in one shot, the scheduler breaks it into smaller chunks (say, 128 tokens each) and interleaves each chunk with a decode step.
This works because decode and prefill waste different resources. Decode is bandwidth-bound: the compute units sit mostly idle while waiting for weight reads. Prefill is compute-bound: it has plenty of arithmetic to do but draws less bandwidth per FLOP. By interleaving them, the prefill chunk fills the idle compute that decode leaves on the table, while decode uses the bandwidth that prefill does not fully need. The two workloads complement each other rather than competing.
The result: decode streams see small, predictable ITL bumps instead of one catastrophic spike. TTFT for the new request is slightly higher than a single monolithic prefill (the prefill work is spread across several steps), but that trade is almost always worth it.
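A rough sketch of that scheduling idea, with made-up structures standing in for real engine state (the request dicts and the `chunk_budget` knob are assumptions, not vLLM's actual interface): each step carries the full decode batch plus at most one budget's worth of prefill tokens.

```python
def plan_step(decoding_ids, pending_prefills, chunk_budget=128):
    """One scheduler iteration with chunked prefill (sketch).

    decoding_ids:     ids of sequences already generating (one token each this step)
    pending_prefills: dicts like {"id": ..., "remaining": prompt tokens left}
    chunk_budget:     max prompt tokens prefilled alongside this decode step
    """
    plan = {"decode": list(decoding_ids), "prefill_chunks": []}
    budget = chunk_budget
    for req in pending_prefills:
        if budget == 0:
            break
        chunk = min(req["remaining"], budget)
        req["remaining"] -= chunk
        budget -= chunk
        plan["prefill_chunks"].append((req["id"], chunk))
        if req["remaining"] == 0:
            plan["decode"].append(req["id"])   # joins the decode loop next step
    return plan

# A 1,000-token prompt becomes ~8 chunks of 128 tokens, each riding along with
# a normal decode step, instead of one ~50 ms prefill that stalls every stream.
```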
Queueing
Everything above assumed the request was already running. In a real service, requests arrive over time and compete for slots in the decode batch.
With continuous batching, a new request starts only after three things happen:
- A slot opens (an existing sequence finishes and frees its KV memory)
- Prefill runs for the new prompt
- The new sequence joins the decode loop
If slots are available, this happens immediately and TTFT is just prefill time. If all slots are full, the request waits in a queue. TTFT becomes prefill time plus queue time, and queue time can dwarf prefill.
How fast does queue time grow? Think about it from the arriving request's perspective. When utilization is low, most slots are open, so a new request almost always finds one immediately. As utilization rises, open slots become rare. Near 100%, almost every request has to wait, and each waiting request makes the line longer for the next one. Requests pile up behind each other.
The relationship is nonlinear because you are dividing by an ever-shrinking gap. Queue wait scales roughly as ρ / (1 − ρ), where ρ is utilization (fraction of capacity in use). At 50% utilization the ratio is 1. At 90% it is 9. At 95% it is 19. The curve is a hockey stick.
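You can see the hockey stick numerically. The sketch below applies the ρ / (1 − ρ) factor to the bare 5 ms prefill from earlier; it is the textbook single-queue approximation, not a model of any particular scheduler, and it gives an average, with tail (p99) waits far worse:

```python
def expected_ttft_ms(utilization, prefill_ms=5.0):
    """Average TTFT under a simple queueing approximation: queue wait grows
    like rho / (1 - rho) times the service time."""
    rho = utilization
    return prefill_ms + prefill_ms * rho / (1 - rho)

for u in (0.5, 0.7, 0.9, 0.95, 0.99):
    print(f"{u:.0%} utilization -> average TTFT ~{expected_ttft_ms(u):.0f} ms")
# 50% -> 10 ms, 70% -> 17 ms, 90% -> 50 ms, 95% -> 100 ms, 99% -> 500 ms
```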
In practice:
- At 70% utilization: queue is usually short, TTFT ≈ bare prefill time
- At 90%: p99 wait times climb sharply
- At 95%+: p99 TTFT can reach seconds, even while average throughput looks healthy
This is the fundamental tension of serving systems. High utilization means high throughput and good hardware efficiency. But p99 TTFT can blow up well before the GPU hits 100%, because queueing math is brutal at the margins.
Goodput
Consider two server configurations:
- Config A: 50,000 tok/s throughput, p99 TTFT = 4 seconds
- Config B: 20,000 tok/s throughput, p99 TTFT = 200 ms
Config A moves more tokens. But if your application requires TTFT under 500 ms, Config A is useless. Its throughput is wasted because the latency target is violated.
The throughput that still meets your latency targets is called goodput. Config A has zero goodput (misses the SLO). Config B has 20,000 tok/s of goodput.
Raw throughput is a misleading north star. A server pinned at 95% utilization with high total tok/s but terrible tail latency is not serving users well. Goodput forces you to account for both dimensions at once, as the sketch below makes explicit.
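Measuring goodput is just filtering throughput through your latency targets. A tiny sketch, with a hypothetical per-request record format and a 500 ms TTFT SLO as the example threshold:

```python
def goodput_tok_per_s(completed, window_s, ttft_slo_ms=500):
    """Output tokens per second, counting only requests whose TTFT met the SLO.
    `completed` holds hypothetical records: {"ttft_ms": ..., "output_tokens": ...}."""
    good_tokens = sum(r["output_tokens"] for r in completed
                      if r["ttft_ms"] <= ttft_slo_ms)
    return good_tokens / window_s

requests = [{"ttft_ms": 180, "output_tokens": 200},
            {"ttft_ms": 4200, "output_tokens": 300},   # fast once started, slow admission
            {"ttft_ms": 240, "output_tokens": 150}]
print(goodput_tok_per_s(requests, window_s=1.0))        # 350.0, not 650.0
```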
Picking an Operating Point
What does this mean in practice? Consider two extremes. An interactive chatbot has a human watching the stream. Every millisecond of TTFT feels like lag, and any stutter in token delivery breaks the illusion of fluency. The operator caps the batch size well below the maximum, leaves headroom for new arrivals, and uses chunked prefill to smooth ITL spikes. Utilization stays moderate. That is the price of responsiveness.
At the other end, a batch-processing pipeline summarizes a million documents overnight. No one is watching. TTFT and ITL are irrelevant. The only metric that matters is tokens per dollar, so the operator fills every slot, pushes utilization as high as queueing allows, and maximizes total throughput.
Most real systems serve a mix. The usual solution is to separate the pools: give interactive traffic a strict concurrency cap with headroom for fast admission, and route batch traffic to a throughput-maximizing pool where high utilization is acceptable.
Batch size is the primary lever. Below the critical batch size, you get both low latency and rising throughput, nearly for free. Above it, every additional sequence trades per-stream ITL for total tok/s. Queueing amplifies the problem at high utilization: even moderate overcommitment can push tail latency far beyond acceptable levels. The operating point you choose depends on what your users are doing, and on one more constraint we have not yet addressed: whether the model fits on a single GPU at all.
Detail 22: Multi-GPU Inference
Detail 21 ended with a question: what if the model does not fit on a single GPU at all?
A 70B model at FP16 needs 70B × 2 bytes = 140 GB just for weights. An A100 has 80 GB of VRAM. The model does not fit.
Even with INT4 quantization (Detail 18), 70B shrinks to roughly 35 GB. That fits in VRAM with room for KV cache. But there is a second problem. Detail 12 established that decode reads all model weights every token. The bandwidth floor for 70B INT4 on a single A100 is 35 GB / 1,555 GB/s ≈ 22.5 ms per token. If your SLO requires inter-token latency under 15 ms, one GPU cannot deliver it. Not a software limitation. Physics.
Multi-GPU inference solves both problems. How you split the model determines which one you fix.
Two Ways to Split
A transformer is a stack of identical layers. There are two ways to distribute them across GPUs:
- Give each GPU a different range of layers. Data flows through GPUs in sequence.
- Give all GPUs a slice of every layer. They work on the same token simultaneously.
The first sounds like the obvious approach. The second turns out to be the important one. Let's see why.
Splitting Across Layers
Start with the naive approach. You have 4 GPUs and a 32-layer model. Assign 8 layers to each GPU:
- GPU 0: layers 1 to 8
- GPU 1: layers 9 to 16
- GPU 2: layers 17 to 24
- GPU 3: layers 25 to 32
A request arrives. GPU 0 processes its 8 layers and passes activations to GPU 1. GPU 1 processes layers 9 to 16 and passes to GPU 2. And so on. Communication is light: one activation tensor at each stage boundary.
But trace a single decode step for one request. GPU 0 works on layers 1 to 8. GPUs 1, 2, and 3 are idle. Then GPU 1 works on layers 9 to 16. GPUs 0, 2, and 3 are idle. At any moment, 3 out of 4 GPUs are doing nothing. That is 75% waste. This idle time is the pipeline bubble.
How do you fill the bubble? Send more items through the pipeline. If you have multiple sequences in flight, GPU 0 can start processing sequence 2 while GPU 1 handles sequence 1. This is microbatching, and continuous batching (Detail 14) provides these microbatches naturally since you already have many active sequences in the decode loop.
The bubble fraction follows a simple formula: (p - 1) / (m + p - 1), where p is the number of pipeline stages and m is the number of microbatches.
With 4 stages:
| Microbatches (m) | Bubble fraction | GPU utilization |
|---|---|---|
| 1 | 75% idle | 25% |
| 4 | 43% idle | 57% |
| 16 | 16% idle | 84% |
| 64 | ~4% idle | ~96% |
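The table is just the bubble formula evaluated at p = 4. A short sketch if you want to try other pipeline shapes:

```python
def bubble_fraction(stages: int, microbatches: int) -> float:
    """Idle fraction of a naive pipeline: (p - 1) / (m + p - 1)."""
    return (stages - 1) / (microbatches + stages - 1)

for m in (1, 4, 16, 64):
    idle = bubble_fraction(stages=4, microbatches=m)
    print(f"m={m:>2}: {idle:.0%} idle, {1 - idle:.0%} utilized")
```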
But notice what this means for a single request. One request still flows through all 4 stages sequentially. Latency equals the sum of stage times. Microbatching fills the idle GPUs with other requests, improving throughput. It does not make any individual request faster.
This approach is called pipeline parallelism (PP). It solves the capacity problem. It improves throughput. It does not reduce single-request inter-token latency.
Splitting Within Layers
Pipeline parallelism could not reduce per-token latency because only one GPU worked on each token at a time. What if all GPUs worked on the same token simultaneously?
Start from a single matrix multiply: y = Wx. Suppose W is a large weight matrix and you have 2 GPUs. Split W column-wise: GPU 0 gets the left half, GPU 1 gets the right half. Each GPU multiplies its half of W by x to get its half of y. Concatenate the two halves. Done. For this single operation, no communication needed.
But the transformer MLP has two consecutive multiplies: an up-projection followed by a down-projection, with an activation function between them. Something like output = GeLU(x · A) · B. The Megatron-LM insight: use column-parallel on A (each GPU gets columns of A, producing part of the hidden representation) and row-parallel on B (each GPU gets rows of B). The row-parallel step produces partial sums: each GPU has computed part of the final dot product along the inner dimension. To get the correct result, you must sum across GPUs. This is an all-reduce.
The cost: 2 all-reduces per transformer layer (one for the attention projection, one for the MLP). For TinyLlama's 22 layers, that is 44 all-reduces per decode step.
The key difference from the pipeline approach: all GPUs work on the same token at the same time. No bubble. But you pay for communication every layer, not just at stage boundaries.
This is tensor parallelism (TP).
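The column-then-row split is easiest to see in code. Here is a NumPy sketch with the two "GPUs" simulated as array slices; the final sum over partial results plays the role of the all-reduce. (The GeLU and the matrix shapes are illustrative, not TinyLlama's actual MLP, which uses SwiGLU.)

```python
import numpy as np

def gelu(x):  # tanh approximation of GeLU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
d_model, d_hidden, n_gpus = 8, 32, 2
x = rng.standard_normal(d_model)
A = rng.standard_normal((d_model, d_hidden))   # up-projection
B = rng.standard_normal((d_hidden, d_model))   # down-projection

# Reference: the un-sharded MLP.
reference = gelu(x @ A) @ B

# Tensor parallel: column-split A, row-split B, one all-reduce at the end.
A_shards = np.split(A, n_gpus, axis=1)         # each "GPU" holds some columns of A
B_shards = np.split(B, n_gpus, axis=0)         # ...and the matching rows of B
partials = [gelu(x @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]
tp_result = np.sum(partials, axis=0)           # the all-reduce: sum partial results

assert np.allclose(reference, tp_result)
```

The column split needs no communication because GeLU is elementwise; only the row-parallel step leaves partial sums that must be reduced across devices.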
The Bandwidth Payoff
This is where the two ideas connect.
In Detail 12 we established that decode reads all model weights every token. The bandwidth floor for TinyLlama on an A100 is ~1.41 ms per token. No software optimization breaks that floor on a single GPU.
With 2-way tensor parallelism, each GPU holds half the weights: 1.1 GB instead of 2.2 GB. Each GPU reads its 1.1 GB from its own VRAM. The bandwidth floor per GPU drops to ~0.71 ms. And since both GPUs work in parallel, the effective floor for the whole model is ~0.71 ms. ITL roughly halved.
| TP degree | Weight read per GPU | Bandwidth floor (A100) |
|---|---|---|
| 1 (single GPU) | 2.2 GB | ~1.41 ms |
| 2-way | 1.1 GB | ~0.71 ms |
| 4-way | 0.55 GB | ~0.35 ms |
Why does this work? Recall from Detail 12 that decode at batch size 1 has an arithmetic intensity of ~1 FLOP/byte. The GPU finishes the actual arithmetic in microseconds and spends the rest of the time waiting for weight data to arrive from VRAM. The compute units are 99% idle, starving for bytes. Splitting weights across GPUs splits the waiting. Each GPU has less to read, and they all read simultaneously.
This is a genuine per-token latency improvement. Pipeline parallelism makes the pipeline longer but each stage does less work, so total latency stays roughly the same. Tensor parallelism makes all GPUs work on the same token in parallel, splitting the bandwidth cost.
Now close the loop from the opening. A 70B model at INT4 on a single A100: ~22.5 ms per token. 2-way TP: ~11.3 ms. 4-way TP: ~5.6 ms. 8-way TP in a DGX node (8 A100s): ~2.8 ms. The latency target that was impossible on one GPU becomes comfortable on eight.
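The per-GPU floor is the same division from Detail 12, with the weight bytes split by the TP degree. This sketch ignores the all-reduce cost, which the next section prices in:

```python
A100_BW = 1555e9  # bytes/s

def itl_floor_ms(weight_gb, tp_degree=1, bw=A100_BW):
    """Bandwidth-only lower bound on inter-token latency: each GPU reads its
    1/tp_degree share of the weights, all GPUs reading in parallel."""
    return (weight_gb * 1e9 / tp_degree) / bw * 1e3

for tp in (1, 2, 4, 8):
    print(f"70B INT4 (~35 GB), TP={tp}: ~{itl_floor_ms(35, tp):.1f} ms/token")
# TP=1: 22.5   TP=2: 11.3   TP=4: 5.6   TP=8: 2.8
```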
The Communication Tax
Those 44 all-reduces per decode step for TinyLlama are not free. How expensive they are depends entirely on the interconnect.
Within a node, NVLink provides ~900 GB/s on H100 (bidirectional). The activation tensors involved in inference all-reduces are small. Overhead is a few percent of decode time. The bandwidth payoff dominates easily.
Across nodes, interconnect bandwidth drops to 50 to 400 GB/s (Ethernet or InfiniBand). All-reduces become a significant fraction of decode time and can erase the bandwidth savings entirely.
Pipeline parallelism communicates only at stage boundaries: one point-to-point activation transfer between adjacent stages per step. Far less frequent than tensor parallelism's per-layer all-reduces. PP tolerates slower interconnects.
This creates the standard deployment pattern: tensor parallelism within a node (where GPUs are connected by fast NVLink), pipeline parallelism across nodes (where communication is infrequent and can tolerate slower links). Match the parallelism strategy to the communication topology.
Large deployments mix strategies: 8-way tensor parallelism within a DGX node (8 GPUs connected by NVLink), pipeline parallelism across nodes if the model is truly huge, and multiple replicas of the whole setup for throughput scaling. The combination gives you capacity, low latency, and high throughput.
The underlying principle is the same one from Detail 12: decode is bandwidth-bound, so anything that reduces or parallelizes the weight read reduces inter-token latency. Multi-GPU inference is the final lever. When one GPU cannot hold the model or cannot deliver enough bandwidth for your latency target, you split the model across devices. How you split depends on your hardware topology, but the bottleneck you are attacking is always the same.
Part 2 Summary: Production Optimizations as Levers
Part 1 gave you the mechanism. Part 2 was the systems view: what happens when that same prefill-decode loop runs for thousands of users at once, and you start caring about VRAM, bandwidth, and queues.
A good way to remember the optimizations is by which constraint they attack:
- Amortize weight reads (throughput): batching and continuous (in-flight) batching.
- Make KV cache fit (concurrency): GQA/MQA, prefix caching, and PagedAttention to avoid fragmentation.
- Make attention feasible at long context: FlashAttention-class kernels.
- Move fewer bytes per token (bandwidth): quantization, and (at the extreme) sharding weights via tensor parallelism.
- Do fewer decode steps (latency): speculative decoding.
- Keep tail latency sane under load: chunked prefill, admission control, and preemption.
When you see "2× faster" in a benchmark, ask what moved: TTFT, ITL, throughput, p99 latency, or cost per token. They are different knobs.
Production Optimization Impact (very workload-dependent):
| Technique | What It Does | Typical effect |
|---|---|---|
| Continuous Batching | Dynamic request scheduling | 2-5× throughput |
| PagedAttention | Eliminate KV cache fragmentation | 2-4× memory efficiency |
| FlashAttention | Fused attention kernels | 2-4× attention speed |
| Quantization | Reduce precision | 2-4× smaller weights (often faster) |
| GQA/MQA | Smaller KV per head | 4-32× KV cache reduction |
| Speculative Decoding | Parallel verification | 2-3× decode speed |
| Prefix Caching | Reuse shared context | 5-10× lower TTFT for cached prefixes |
| Multi-GPU (TP) | Split weights across devices | Near-linear ITL reduction (until comm dominates) |
Used together, these often compound because they target different bottlenecks. There are no infinite wins: eventually bandwidth, VRAM, or queueing becomes the binding constraint. But none of these techniques are mysterious once you have the Part 1 mental model. Each one is a direct response to a specific bottleneck we identified.
Summary: From Black Box to a Useful Mental Model
We started with a black box: text goes in, text comes out. Twenty-two details later, there is no box left.
If you collapse the whole post into one sentence:
Inference is a prefill pass that builds a KV cache, followed by a bandwidth-limited decode loop that repeatedly rereads weights and appends to that cache.
That sentence would have meant nothing at the start of this post. Now every word in it maps to something concrete: you know what prefill does and why it's compute-bound, what the KV cache stores and how it grows, why decode is bandwidth-limited, and what "appends to that cache" means at the hardware level.
Most real questions about inference performance reduce to that loop:
- Why is TTFT high? Prefill is heavy, or the request waited in a queue.
- Why is ITL high? You're bandwidth-limited, or paying KV traffic and synchronization overhead.
- Why is throughput flat? You're no longer amortizing the weight read, or KV memory is throttling batch size.
- Why did switching to GQA help? Smaller KV per head, more sequences fit in memory, larger batches, better bandwidth utilization.
You can keep pulling on threads like these and the answers trace back to the same fundamentals.
This was a long read. If you worked through it, you now carry the same mental model that people building inference systems use daily. Not a surface-level familiarity with the terminology, but the actual framework: prefill vs decode, bandwidth bounds, KV cache budgets, the tension between latency and throughput. When you read a vLLM changelog, a quantization paper, or a benchmark comparing inference backends, you have the vocabulary to follow the argument and the intuition to evaluate it on your own terms.
Optional: Going Deeper
What you know so far should be enough to build intuition and read most papers in the inference-optimization space. Everything below is optional: deep dives for readers who want more detail on specific topics.
References & Further Reading
Some of these I read cover to cover, others I skimmed for specific details, and the rest are here because they are the primary sources for concepts discussed above. Included for anyone who wants to go deeper on a particular topic.
Blog posts & videos
- Aleksa Gordić, Inside vLLM: Anatomy of a High-Throughput LLM Inference System. https://www.aleksagordic.com/blog/vllm
- Vizuara, How the VLLM Inference Engine Works. https://www.youtube.com/watch?v=QyHHbeXqgrQ
- Lilian Weng, Large Transformer Model Inference Optimization. https://lilianweng.github.io/posts/2023-01-10-inference-optimization/
- NVIDIA Developer Blog, Mastering LLM Techniques: Inference Optimization. https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/
- Jay Alammar, The Illustrated Transformer. https://jalammar.github.io/illustrated-transformer/
- Andrej Karpathy, Let's build GPT: from scratch, in code, spelled out. https://www.youtube.com/watch?v=kCc8FmEb1nY
Papers (mostly arXiv)
Core architecture
- Vaswani et al. (2017). Attention Is All You Need. https://arxiv.org/abs/1706.03762
- Bahdanau et al. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. https://arxiv.org/abs/1409.0473
- Su et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. https://arxiv.org/abs/2104.09864
- He et al. (2016). Deep Residual Learning for Image Recognition. https://arxiv.org/abs/1512.03385
- Ba et al. (2016). Layer Normalization. https://arxiv.org/abs/1607.06450
- Shazeer (2020). GLU Variants Improve Transformer (SwiGLU). https://arxiv.org/abs/2002.05202
Tokenization
- Sennrich et al. (2016). Neural Machine Translation of Rare Words with Subword Units (BPE). https://arxiv.org/abs/1508.07909
- Kudo & Richardson (2018). SentencePiece: A Simple and Language Independent Subword Tokenizer and Detokenizer for Neural Text Processing. https://arxiv.org/abs/1808.06226
Models referenced
- Zhang et al. (2024). TinyLlama: An Open-Source Small Language Model. https://arxiv.org/abs/2401.02385
- Touvron et al. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. https://arxiv.org/abs/2307.09288
- Llama Team (2024). The Llama 3 Herd of Models. https://arxiv.org/abs/2407.21783
- Jiang et al. (2023). Mistral 7B. https://arxiv.org/abs/2310.06825
- Gemma Team (2024). Gemma: Open Models Based on Gemini Research and Technology. https://arxiv.org/abs/2403.08295
Inference + serving
- Kwon et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention (vLLM). https://arxiv.org/abs/2309.06180
- Yu et al. (2022). Orca: A Distributed Serving System for Transformer-Based Generative Models (OSDI). https://www.usenix.org/conference/osdi22/presentation/yu
Attention kernels
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/abs/2205.14135
- Dao (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691
- Milakov & Gimelshein (2018). Online Normalizer Calculation for Softmax. https://arxiv.org/abs/1805.02867
KV-cache shape tricks
- Shazeer (2019). Fast Transformer Decoding: One Write-Head Is All You Need (MQA). https://arxiv.org/abs/1911.02150
- Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. https://arxiv.org/abs/2305.13245
- DeepSeek-AI (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (MLA). https://arxiv.org/abs/2405.04434
Speculative decoding
- Leviathan et al. (2022). Fast Inference from Transformers via Speculative Decoding. https://arxiv.org/abs/2211.17192
- Chen et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. https://arxiv.org/abs/2302.01318
- Li et al. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. https://arxiv.org/abs/2401.15077
- Cai et al. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. https://arxiv.org/abs/2401.10774
- Fu et al. (2024). Break the Sequential Dependency of LLM Inference Using Lookahead Decoding. https://arxiv.org/abs/2402.02057
Quantization
- Frantar et al. (2022). GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers. https://arxiv.org/abs/2210.17323
- Xiao et al. (2022). SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models. https://arxiv.org/abs/2211.10438
- Lin et al. (2023). AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration. https://arxiv.org/abs/2306.00978
- Dettmers et al. (2022). LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale. https://arxiv.org/abs/2208.07339
Parallelism
- Shoeybi et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. https://arxiv.org/abs/1909.08053
- Huang et al. (2018). GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism. https://arxiv.org/abs/1811.06965
Performance modeling
- Williams et al. (2009). Roofline: An Insightful Visual Performance Model for Multicore Architectures. Communications of the ACM, 52(4), 65-76. https://doi.org/10.1145/1498765.1498785
Alternate attention (context for long sequences)
- Kitaev et al. (2020). Reformer: The Efficient Transformer. https://arxiv.org/abs/2001.04451
- Wang et al. (2020). Linformer: Self-Attention with Linear Complexity. https://arxiv.org/abs/2006.04768
- Beltagy et al. (2020). Longformer: The Long-Document Transformer. https://arxiv.org/abs/2004.05150
Compression beyond quantization (optional)
- Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. https://arxiv.org/abs/1910.01108
- Frantar & Alistarh (2023). SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot. https://arxiv.org/abs/2301.00774
- Shazeer et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. https://arxiv.org/abs/1701.06538
- Fedus et al. (2021). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. https://arxiv.org/abs/2101.03961
Implementations & docs
- vLLM: https://github.com/vllm-project/vllm
- FlashAttention: https://github.com/Dao-AILab/flash-attention
- Hugging Face TGI: https://github.com/huggingface/text-generation-inference
- NVIDIA TensorRT-LLM: https://github.com/NVIDIA/TensorRT-LLM
- LMCache (prefix caching): https://github.com/LMCache/LMCache
- XGrammar (structured generation): https://arxiv.org/abs/2411.15100