"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
Module Objective
Master Multi-Head Self-Attention — the core parallel computation engine of Transformers — with math, code, intuition, and divide-and-conquer parallelism.
1. Self-Attention: One Token Talks to All
Self-Attention: the queries, keys, and values are all derived from the same sequence ($ Q = K = V = X $ before the projections).
Every token attends to every other token in the same sequence.
# Input: [batch, seq_len, d_model]
X → Linear → Q, K, V → Attention(Q, K, V)
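To make this concrete, here is a minimal single-head sketch of scaled dot-product attention (plain tensors, no learned projections yet); the full multi-head module follows in section 5.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k) -- every token's query scores every token's key
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ V                              # (batch, seq_len, d_k)

x = torch.randn(2, 10, 64)
out = scaled_dot_product_attention(x, x, x)  # self-attention: Q = K = V = x
print(out.shape)  # torch.Size([2, 10, 64])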
2. Why Multi-Head? Divide & Conquer
| Problem | Solution |
|---|---|
| One attention head = one perspective | Multiple heads = multiple subspaces |
| Risk of missing relations | Parallel views → richer representation |
"Let the model attend to information from different representation subspaces at different positions."
— Vaswani et al., 2017
3. Multi-Head Attention — The Formula
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$
$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
Where:
- $ h $ = number of heads
- $ d_k = d_v = d_{\text{model}} / h $
- $ W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k} $, $ W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v} $
- $ W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}} $
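A direct, if inefficient, translation of the formula as a sketch: one explicit projection triple per head, then concatenation and $W^O$. Variable names are illustrative; the fused implementation in section 5 is equivalent.

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, h = 512, 8
d_k = d_v = d_model // h                      # 64

# One (d_model x d_k) projection triple per head, plus the output projection W^O
W_q = [nn.Linear(d_model, d_k, bias=False) for _ in range(h)]
W_k = [nn.Linear(d_model, d_k, bias=False) for _ in range(h)]
W_v = [nn.Linear(d_model, d_v, bias=False) for _ in range(h)]
W_o = nn.Linear(h * d_v, d_model, bias=False)

def attention(Q, K, V):
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    return F.softmax(scores, dim=-1) @ V

X = torch.randn(2, 10, d_model)
heads = [attention(W_q[i](X), W_k[i](X), W_v[i](X)) for i in range(h)]  # h tensors of (2, 10, d_k)
out = W_o(torch.cat(heads, dim=-1))                                     # (2, 10, d_model)
print(out.shape)  # torch.Size([2, 10, 512])

In practice the $h$ small projections are fused into single d_model × d_model layers, which is exactly what the class in section 5 does.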
4. Step-by-Step: From Single to Multi-Head
| Step | Operation | Shape |
|---|---|---|
| 1 | Project $ X $ → $ Q, K, V $ | $ (B, N, d) $ |
| 2 | Split into $ h $ heads | $ (B, h, N, d/h) $ |
| 3 | Parallel attention | $ h $ heads → $ (B, h, N, d/h) $ |
| 4 | Concat + Linear | $ \to (B, N, d) $ |
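The same four steps traced with raw tensor ops; the random matrices stand in for the learned projection weights, so only the shapes matter here.

import torch

B, N, d, h = 2, 10, 512, 8
X = torch.randn(B, N, d)

# 1. Project X -> Q, K, V (random d x d matrices as stand-ins for learned weights)
Q = X @ torch.randn(d, d); K = X @ torch.randn(d, d); V = X @ torch.randn(d, d)  # (B, N, d)

# 2. Split into h heads
Q = Q.view(B, N, h, d // h).transpose(1, 2)   # (B, h, N, d/h)
K = K.view(B, N, h, d // h).transpose(1, 2)
V = V.view(B, N, h, d // h).transpose(1, 2)

# 3. Parallel attention (one batched matmul covers all heads)
scores = Q @ K.transpose(-2, -1) / (d // h) ** 0.5   # (B, h, N, N)
out = torch.softmax(scores, dim=-1) @ V              # (B, h, N, d/h)

# 4. Concat heads + final linear
out = out.transpose(1, 2).contiguous().view(B, N, d) @ torch.randn(d, d)  # (B, N, d)
print(out.shape)  # torch.Size([2, 10, 512])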
5. Multi-Head Attention — From Scratch (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # d_v = d_k

        # Learnable projections
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5

    def split_heads(self, x):
        """Split last dim into (num_heads, d_k)."""
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (B, h, N, d_k)

    def combine_heads(self, x):
        """Combine heads back to (B, N, d_model)."""
        batch, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.d_model)

    def scaled_dot_product(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, h, N, N)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, V), attn  # (B, h, N, d_k), (B, h, N, N)

    def forward(self, Q, K, V, mask=None):
        # 1. Linear projections
        Q = self.W_q(Q)  # (B, N, d)
        K = self.W_k(K)
        V = self.W_v(V)
        # 2. Split into heads
        Q = self.split_heads(Q)  # (B, h, N, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        # 3. Apply attention in parallel
        attn_output, attn_weights = self.scaled_dot_product(Q, K, V, mask)  # (B, h, N, d_k)
        # 4. Combine heads
        output = self.combine_heads(attn_output)  # (B, N, d)
        # 5. Final linear
        output = self.W_o(output)
        return output, attn_weights
6. Self-Attention = Multi-Head(Q=X, K=X, V=X)
# Self-Attention
x = torch.randn(2, 10, 512) # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn = mha(x, x, x) # Q=K=V=x
print(output.shape) # (2, 10, 512)
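The mask argument is optional. Continuing the snippet above, a minimal causal-mask sketch, assuming (as the masked_fill in the class does) that 0 marks positions to block and the mask broadcasts to (B, h, N, N):

# Causal (autoregressive) self-attention: token i may only attend to tokens <= i
N = 10
causal_mask = torch.tril(torch.ones(N, N)).view(1, 1, N, N)  # 1 = keep, 0 = block
output, attn = mha(x, x, x, mask=causal_mask)
print(attn[0, 0])  # upper triangle is ~0 after softmax over -inf scores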
7. Parallelism: Divide & Conquer
Hardware View (GPU)
Input X ──> [Linear W_q] ──┐
Input X ──> [Linear W_k] ──┼──> split each into h heads ──> [Head 1 ... Head 8] run attention in parallel ──> Concat ──> W^O
Input X ──> [Linear W_v] ──┘
All 8 heads run in parallel on the GPU.
Memory for the attention maps: $ O(h \cdot N^2) $, i.e. still quadratic in $ N $, but the features are richer.
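A sketch of why the head dimension is "free" parallelism: a Python loop over heads and one batched matmul over the (B, h) dimensions compute exactly the same result, but the batched form launches far fewer kernels.

import torch
import torch.nn.functional as F

B, h, N, d_k = 2, 8, 10, 64
Q = torch.randn(B, h, N, d_k)
K = torch.randn(B, h, N, d_k)
V = torch.randn(B, h, N, d_k)

# Sequential view: one head at a time
per_head = [F.softmax(Q[:, i] @ K[:, i].transpose(-2, -1) / d_k ** 0.5, dim=-1) @ V[:, i]
            for i in range(h)]
seq_out = torch.stack(per_head, dim=1)             # (B, h, N, d_k)

# Parallel view: all heads in one batched matmul
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5      # (B, h, N, N)
par_out = F.softmax(scores, dim=-1) @ V            # (B, h, N, d_k)

print(torch.allclose(seq_out, par_out, atol=1e-6))  # True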
8. Visualization: Multi-Head Attention Maps
import matplotlib.pyplot as plt
import seaborn as sns

# Dummy input
x = torch.randn(1, 5, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8)
_, attn_weights = mha(x, x, x)  # (B, h, N, N)

# Plot all heads
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.flatten()
for i in range(8):
    sns.heatmap(
        attn_weights[0, i].detach().cpu(),
        ax=axes[i],
        cmap="viridis",
        cbar=False
    )
    axes[i].set_title(f"Head {i+1}")
    axes[i].set_xticks([])
    axes[i].set_yticks([])
plt.suptitle("Multi-Head Attention Weights (8 Heads)", fontsize=16)
plt.tight_layout()
plt.show()
In a trained model, each head tends to learn a different pattern, for example:
- one head attends locally to neighboring tokens,
- another attends globally across the whole sequence,
- another tracks syntactic relations,
- and so on. (With the untrained, random weights above, the maps are just noise.)
9. Efficiency: Memory & Compute
| Operation | Time | Memory |
|---|---|---|
| Linear Projections | $ O(N d^2) $ | $ O(N d) $ |
| Split Heads | $ O(N d) $ | $ O(N d) $ |
| Attention (per head) | $ O(N^2 d/h) $ | $ O(N^2) $ |
| Total | $ O(N d^2 + N^2 d) $ | $ O(h N^2 + N d) $ |
Same complexity as single head, but richer output
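A rough timing sketch (using the MultiHeadAttention class from section 5) to see the quadratic growth in $N$; absolute numbers depend entirely on your hardware.

import time
import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
mha.eval()

for N in [128, 256, 512, 1024]:
    x = torch.randn(1, N, 512)
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(10):
            mha(x, x, x)
        elapsed = (time.perf_counter() - start) / 10
    print(f"N={N:5d}  avg forward time: {elapsed * 1e3:.2f} ms")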
10. Divide & Conquer Intuition
Single head (one 64-dim view):
"the cat sat on the mat"
└──────── one view ───────┘
Multi-head (8 views of 8 dims each):
"the cat sat on the mat"
├─> "the" ↔ determiners
├─> "cat" ↔ nouns / animals
├─> "sat" ↔ verbs
└─> "on"  ↔ prepositions
Each head specializes → a richer combined representation.
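A small check that "splitting the embedding space" literally means each head sees a disjoint 64-dim slice of the projected vector (again using the class from section 5):

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(1, 10, 512)

q = mha.W_q(x)                 # (1, 10, 512), the full projected queries
q_heads = mha.split_heads(q)   # (1, 8, 10, 64)

# Head 3's queries are exactly dimensions 192..255 of the projected vector
print(torch.equal(q_heads[:, 3], q[..., 3 * 64:4 * 64]))  # True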
11. Full Transformer Block with Self-Attention
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + Residual
        attn_out, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Feed Forward + Residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x, attn_weights
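A quick usage sketch, stacking a few blocks; the depth and sizes here are arbitrary, not the paper's configuration.

blocks = nn.ModuleList([
    TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
    for _ in range(6)
])

x = torch.randn(2, 32, 512)
for block in blocks:
    x, attn = block(x)
print(x.shape)  # torch.Size([2, 32, 512])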
12. Test: Multi-Head vs Single Head
x = torch.randn(1, 32, 512)
mha_8 = MultiHeadAttention(512, 8)
mha_1 = MultiHeadAttention(512, 1)
out_8, _ = mha_8(x, x, x)
out_1, _ = mha_1(x, x, x)
print("8 heads output norm:", out_8.norm().item())
print("1 head output norm:", out_1.norm().item())
Both configurations produce the same output shape and have the same parameter count; the benefit of multiple heads shows up in representation quality only after training.
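Continuing the snippet above, a quick check that the two configurations also match in parameter count, which backs up the "same cost" claim:

params_8 = sum(p.numel() for p in mha_8.parameters())
params_1 = sum(p.numel() for p in mha_1.parameters())
print(params_8, params_1, params_8 == params_1)  # identical: four 512x512 weight matrices each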
13. Summary Cheat Sheet
| Concept | Value |
|---|---|
| Self-Attention | Q = K = V = X |
| Multi-Head | $ h $ parallel attention layers |
| Head Dim | $ d_k = d_{\text{model}} / h $ |
| Split | .view(B, N, h, d_k).transpose(1,2) |
| Combine | .transpose(1,2).contiguous().view(B, N, d) |
| Parallelism | GPU runs all heads at once |
| Complexity | $ O(N^2 d) $ (same as single) |
14. Practice Exercises
- Ablate: Train with 1 vs 8 heads and compare performance on a copy task.
- Visualize: Plot the attention maps of each head on real sentences.
- Efficiency: Measure forward time for num_heads = 1, 8, 16.
- Custom: Implement multi-query attention (MQA) or grouped-query attention (GQA); a minimal MQA sketch follows this list.
- Debug: Add print(x.shape) calls inside forward() to trace tensor dims.
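For the MQA exercise, a minimal sketch of the idea: all h query heads share a single key/value head. This is an illustration of the mechanism, not a tuned implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """Multi-query attention: h query heads, one shared K/V head."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.h, self.d_k = num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)   # h query heads
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)  # single key head
        self.W_v = nn.Linear(d_model, self.d_k, bias=False)  # single value head
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        Q = self.W_q(x).view(B, N, self.h, self.d_k).transpose(1, 2)  # (B, h, N, d_k)
        K = self.W_k(x).unsqueeze(1)                                  # (B, 1, N, d_k), broadcast over heads
        V = self.W_v(x).unsqueeze(1)
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5            # (B, h, N, N)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        out = F.softmax(scores, dim=-1) @ V                           # (B, h, N, d_k)
        out = out.transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(out)

mqa = MultiQueryAttention(d_model=512, num_heads=8)
print(mqa(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])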
15. Key Takeaways
| ✓ | Insight |
|---|---|
| ✓ | Self-Attention = intra-sequence communication |
| ✓ | Multi-Head = parallel feature extractors |
| ✓ | Divide & Conquer = split the embedding space |
| ✓ | Same asymptotic cost, better performance |
| ✓ | Enables head specialization |
Final Words
You just built the brain of every modern LLM.
Multi-Head Self-Attention = parallel, rich, scalable context.
Full Copy-Paste Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.h, self.d_k = d_model, num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        B = Q.shape[0]
        Q, K, V = self.W_q(Q), self.W_k(K), self.W_v(V)
        Q = Q.view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = K.view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = V.view(B, -1, self.h, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(out), attn
End of Module
You now control parallel attention — the heart of GPT, BERT, and beyond.
Go stack 100 layers.
"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
"Attention is All You Need" — Multi-Head & Self-Attention
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch
Module Objective
Master Multi-Head Self-Attention — the core parallel computation engine of Transformers — with math, code, intuition, and divide-and-conquer parallelism.
1. Self-Attention: One Token Talks to All
Self-Attention = $ Q = K = V $
Every token attends to every other token in the same sequence.
# Input: [batch, seq_len, d_model]
X → Linear → Q, K, V → Attention(Q, K, V)
2. Why Multi-Head? Divide & Conquer
| Problem | Solution |
|---|---|
| One attention head = one perspective | Multiple heads = multiple subspaces |
| Risk of missing relations | Parallel views → richer representation |
"Let the model attend to information from different representation subspaces at different positions."
— Vaswani et al., 2017
3. Multi-Head Attention — The Formula
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$
$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
Where:
- $ h $ = number of heads
- $ d_k = d_v = d_{\text{model}} / h $
- $ W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k} $
- $ W^O \in \mathbb{R}^{h d_v \times d} $
4. Step-by-Step: From Single to Multi-Head
| Step | Operation | Shape |
|---|---|---|
| 1 | Project $ X $ → $ Q, K, V $ | $ (B, N, d) $ |
| 2 | Split into $ h $ heads | $ (B, h, N, d/h) $ |
| 3 | Parallel attention | $ h $ heads → $ (B, h, N, d/h) $ |
| 4 | Concat + Linear | $ \to (B, N, d) $ |
5. Multi-Head Attention — From Scratch (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # d_v = d_k
# Learnable projections
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
self.scale = (self.d_k) ** 0.5
def split_heads(self, x):
"""Split last dim into (num_heads, d_k)"""
batch, seq_len, _ = x.shape
x = x.view(batch, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (B, h, N, d_k)
def combine_heads(self, x):
"""Combine heads back to (B, N, d_model)"""
batch, _, seq_len, _ = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(batch, seq_len, self.d_model)
def scaled_dot_product(self, Q, K, V, mask=None):
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # (B, h, N, N)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
return torch.matmul(attn, V), attn # (B, h, N, d_k), (B, h, N, N)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1. Linear projections
Q = self.W_q(Q) # (B, N, d)
K = self.W_k(K)
V = self.W_v(V)
# 2. Split into heads
Q = self.split_heads(Q) # (B, h, N, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# 3. Apply attention in parallel
attn_output, attn_weights = self.scaled_dot_product(Q, K, V, mask)
# attn_output: (B, h, N, d_k)
# 4. Combine heads
output = self.combine_heads(attn_output) # (B, N, d)
# 5. Final linear
output = self.W_o(output)
return output, attn_weights
6. Self-Attention = Multi-Head(Q=X, K=X, V=X)
# Self-Attention
x = torch.randn(2, 10, 512) # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn = mha(x, x, x) # Q=K=V=x
print(output.shape) # (2, 10, 512)
7. Parallelism: Divide & Conquer
Hardware View (GPU)
Input X → [Linear Q] → Split → [Head 1] → Attention
[Linear K] → Split → [Head 2] → Attention → Concat → W^O
[Linear V] → Split → [Head 3] → Attention
...
[Head 8] → Attention
All 8 heads run in parallel on GPU
Memory: $ O(h \cdot N^2) $ → still $ O(N^2) $, but richer features
8. Visualization: Multi-Head Attention Maps
import matplotlib.pyplot as plt
import seaborn as sns
# Dummy input
x = torch.randn(1, 5, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8)
_, attn_weights = mha(x, x, x) # (B, h, N, N)
# Plot all heads
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.flatten()
for i in range(8):
sns.heatmap(
attn_weights[0, i].detach().cpu(),
ax=axes[i],
cmap="viridis",
cbar=False
)
axes[i].set_title(f"Head {i+1}")
axes[i].set_xticks([])
axes[i].set_yticks([])
plt.suptitle("Multi-Head Attention Weights (8 Heads)", fontsize=16)
plt.tight_layout()
plt.show()
Each head learns different patterns:
- Head 1: Local
- Head 2: Global
- Head 3: Syntax
- etc.
9. Efficiency: Memory & Compute
| Operation | Time | Memory |
|---|---|---|
| Linear Projections | $ O(N d^2) $ | $ O(N d) $ |
| Split Heads | $ O(N d) $ | $ O(N d) $ |
| Attention (per head) | $ O(N^2 d/h) $ | $ O(N^2) $ |
| Total | $ O(N^2 d) $ | $ O(N^2 + N d) $ |
Same complexity as single head, but richer output
10. Divide & Conquer Intuition
Single Head (64-dim):
"the cat sat on the mat"
└────┬────┘
One view
Multi-Head (8 × 8-dim):
"the cat sat on the mat"
├─> "the" ↔ pronouns
├─> "cat" ↔ animals
├─> "sat" ↔ verbs
└─> "on" ↔ prepositions
Each head specializes → emergent behavior
11. Full Transformer Block with Self-Attention
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-Attention + Residual
attn_out, attn_weights = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_out))
# Feed Forward + Residual
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x, attn_weights
12. Test: Multi-Head vs Single Head
x = torch.randn(1, 32, 512)
mha_8 = MultiHeadAttention(512, 8)
mha_1 = MultiHeadAttention(512, 1)
out_8, _ = mha_8(x, x, x)
out_1, _ = mha_1(x, x, x)
print("8 heads output norm:", out_8.norm().item())
print("1 head output norm:", out_1.norm().item())
8 heads → richer, more stable representations
13. Summary Cheat Sheet
| Concept | Value |
|---|---|
| Self-Attention | Q = K = V = X |
| Multi-Head | $ h $ parallel attention layers |
| Head Dim | $ d_k = d_{\text{model}} / h $ |
| Split | .view(B, N, h, d_k).transpose(1,2) |
| Combine | .transpose(1,2).view(B, N, d) |
| Parallelism | GPU runs all heads at once |
| Complexity | $ O(N^2 d) $ (same as single) |
14. Practice Exercises
- Ablate: Train with 1 vs 8 heads → compare performance on copy task.
- Visualize: Plot attention for each head on real sentences.
- Efficiency: Measure time for
num_heads=1, 8, 16. - Custom: Implement grouped-query attention (MQA).
- Debug: Add
print(shape)inforward()to trace tensor dims.
15. Key Takeaways
| Check | Insight |
|---|---|
| Check | Self-Attention = intra-sequence communication |
| Check | Multi-Head = parallel feature extractors |
| Check | Divide & Conquer = split embedding space |
| Check | Same cost, better performance |
| Check | Enables specialization |
Final Words
You just built the brain of every modern LLM.
Multi-Head Self-Attention = parallel, rich, scalable context.
Full Copy-Paste Code
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model, self.h, self.d_k = d_model, num_heads, d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = self.d_k ** 0.5
def forward(self, Q, K, V, mask=None):
B = Q.shape[0]
Q, K, V = self.W_q(Q), self.W_k(K), self.W_v(V)
Q = Q.view(B, -1, self.h, self.d_k).transpose(1, 2)
K = K.view(B, -1, self.h, self.d_k).transpose(1, 2)
V = V.view(B, -1, self.h, self.d_k).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / self.scale
if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
attn = self.dropout(F.softmax(scores, dim=-1))
out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
return self.W_o(out), attn
End of Module
You now control parallel attention — the heart of GPT, BERT, and beyond.
Go stack 100 layers.