"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch

"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch

"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch


Module Objective

Master Multi-Head Self-Attention — the core parallel computation engine of Transformers — with math, code, intuition, and divide-and-conquer parallelism.


1. Self-Attention: One Token Talks to All

Self-Attention: queries, keys, and values all come from the same input ($ Q = K = V = X $ before the linear projections).
Every token attends to every other token in the same sequence.

# Input: [batch, seq_len, d_model]
X → Linear → Q, K, V → Attention(Q, K, V)
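
Before going multi-head, here is a minimal single-head sketch of what Attention(Q, K, V) computes — just the scaled dot-product, without the learned projections that the full module in section 5 adds:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k)
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (batch, N, N)
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                             # (batch, N, d_k)

x = torch.randn(2, 10, 64)
out = scaled_dot_product_attention(x, x, x)  # self-attention: Q = K = V = x
print(out.shape)  # torch.Size([2, 10, 64])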

2. Why Multi-Head? Divide & Conquer

| Problem | Solution |
|---|---|
| One attention head = one perspective | Multiple heads = multiple subspaces |
| Risk of missing relations | Parallel views → richer representation |

"Let the model attend to information from different representation subspaces at different positions."
Vaswani et al., 2017


3. Multi-Head Attention — The Formula

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$

$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

Where:
- $ h $ = number of heads
- $ d_k = d_v = d_{\text{model}} / h $
- $ W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k} $
- $ W^O \in \mathbb{R}^{h d_v \times d} $
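
A quick sanity check of these dimensions, using the base configuration from the paper ($ d_{\text{model}} = 512 $, $ h = 8 $):

d_model, h = 512, 8
d_k = d_v = d_model // h   # 64 — each head works in a 64-dim subspace
concat_dim = h * d_v       # 512 — concatenating the heads restores d_model
print(d_k, concat_dim)     # 64 512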


4. Step-by-Step: From Single to Multi-Head

| Step | Operation | Shape |
|---|---|---|
| 1 | Project $ X $ → $ Q, K, V $ | $ (B, N, d) $ |
| 2 | Split into $ h $ heads | $ (B, h, N, d/h) $ |
| 3 | Run attention on all $ h $ heads in parallel | $ (B, h, N, d/h) $ |
| 4 | Concat + linear $ W^O $ | $ (B, N, d) $ |
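
A minimal shape trace of steps 2 and 4 in isolation — splitting into heads and combining them back is a lossless reshape (the numbers below are illustrative, not tied to any model):

import torch

B, N, d_model, h = 2, 10, 512, 8
d_k = d_model // h
X = torch.randn(B, N, d_model)

heads = X.view(B, N, h, d_k).transpose(1, 2)                    # (B, h, N, d_k)
back  = heads.transpose(1, 2).contiguous().view(B, N, d_model)  # (B, N, d_model)

print(heads.shape, back.shape, torch.equal(X, back))
# torch.Size([2, 8, 10, 64]) torch.Size([2, 10, 512]) True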

5. Multi-Head Attention — From Scratch (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # d_v = d_k

        # Learnable projections
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = (self.d_k) ** 0.5

    def split_heads(self, x):
        """Split last dim into (num_heads, d_k)"""
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (B, h, N, d_k)

    def combine_heads(self, x):
        """Combine heads back to (B, N, d_model)"""
        batch, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.d_model)

    def scaled_dot_product(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, h, N, N)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, V), attn  # (B, h, N, d_k), (B, h, N, N)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 1. Linear projections
        Q = self.W_q(Q)  # (B, N, d)
        K = self.W_k(K)
        V = self.W_v(V)

        # 2. Split into heads
        Q = self.split_heads(Q)  # (B, h, N, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # 3. Apply attention in parallel
        attn_output, attn_weights = self.scaled_dot_product(Q, K, V, mask)
        # attn_output: (B, h, N, d_k)

        # 4. Combine heads
        output = self.combine_heads(attn_output)  # (B, N, d)

        # 5. Final linear
        output = self.W_o(output)

        return output, attn_weights

6. Self-Attention = Multi-Head(Q=X, K=X, V=X)

# Self-Attention
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn = mha(x, x, x)  # Q=K=V=x
print(output.shape)  # (2, 10, 512)
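
The mask argument has not been used yet. A sketch of causal (autoregressive) masking with the class above — a (1, 1, N, N) lower-triangular mask broadcasts across batch and heads:

N = x.size(1)                                                # 10
causal_mask = torch.tril(torch.ones(N, N)).view(1, 1, N, N)  # 1 = attend, 0 = block
output, attn = mha(x, x, x, mask=causal_mask)
print(attn[0, 0])  # upper triangle is ~0: no token attends to future positions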

7. Parallelism: Divide & Conquer

Hardware View (GPU)

Input X → [Linear Q] ┐
          [Linear K] ├→ Split into h heads → [Head 1] → Attention ┐
          [Linear V] ┘                       [Head 2] → Attention ├→ Concat → W^O
                                             ...                  │
                                             [Head 8] → Attention ┘

All 8 heads run in parallel on the GPU as one batched matmul.
Memory for attention weights: $ O(h \cdot N^2) $ — still quadratic in $ N $ (with $ h $ fixed), but with $ h $ different views of the sequence.
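
A small check that one batched matmul really does cover every head — PyTorch treats the leading (B, h) dimensions of a 4-D tensor as batch dimensions (shapes below are arbitrary):

import torch

B, h, N, d_k = 2, 8, 10, 64
Q = torch.randn(B, h, N, d_k)
K = torch.randn(B, h, N, d_k)

scores = Q @ K.transpose(-2, -1)  # (B, h, N, N): all heads in one call

# Equivalent per-head loop — same numbers, just slower
scores_loop = torch.stack(
    [Q[:, i] @ K[:, i].transpose(-2, -1) for i in range(h)], dim=1
)
print(torch.allclose(scores, scores_loop))  # True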


8. Visualization: Multi-Head Attention Maps

import matplotlib.pyplot as plt
import seaborn as sns

# Dummy input — weights are random, so these maps are illustrative only
x = torch.randn(1, 5, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8)
mha.eval()  # disable dropout so each attention row sums to 1
_, attn_weights = mha(x, x, x)  # (B, h, N, N)

# Plot all heads
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.flatten()

for i in range(8):
    sns.heatmap(
        attn_weights[0, i].detach().cpu(),
        ax=axes[i],
        cmap="viridis",
        cbar=False
    )
    axes[i].set_title(f"Head {i+1}")
    axes[i].set_xticks([])
    axes[i].set_yticks([])

plt.suptitle("Multi-Head Attention Weights (8 Heads)", fontsize=16)
plt.tight_layout()
plt.show()

In a trained model, heads tend to learn different patterns, for example:
- one head attends locally (neighboring tokens)
- another attends globally (long-range context)
- another tracks syntactic relations
(with the random weights above, the maps are just noise)


9. Efficiency: Memory & Compute

| Operation | Time | Memory |
|---|---|---|
| Linear projections | $ O(N d^2) $ | $ O(N d) $ |
| Split heads | $ O(N d) $ | $ O(N d) $ |
| Attention (per head) | $ O(N^2 d / h) $ | $ O(N^2) $ |
| Total | $ O(N^2 d + N d^2) $ | $ O(h N^2 + N d) $ |

Same complexity as single head, but richer output
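
A rough way to check this empirically — a minimal CPU timing sketch reusing the MultiHeadAttention class from section 5 (on GPU you would also need torch.cuda.synchronize() for honest numbers; exact timings depend on hardware):

import time
import torch

x = torch.randn(8, 128, 512)
for heads in (1, 8, 16):
    mha = MultiHeadAttention(d_model=512, num_heads=heads)
    mha.eval()
    with torch.no_grad():
        t0 = time.perf_counter()
        for _ in range(20):
            mha(x, x, x)
    print(f"{heads:>2} heads: {time.perf_counter() - t0:.3f}s")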


10. Divide & Conquer Intuition

Single Head (64-dim):
"the cat sat on the mat"
       └────┬────┘
           One view

Multi-Head (8 × 8-dim):
"the cat sat on the mat"
 ├─> "the" ↔ pronouns
 ├─> "cat" ↔ animals
 ├─> "sat" ↔ verbs
 └─> "on" ↔ prepositions

Each head specializes → emergent behavior


11. Full Transformer Block with Self-Attention

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + Residual
        attn_out, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed Forward + Residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x, attn_weights
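
A quick shape check of the block, plus one way to stack several of them into a small encoder (d_ff = 2048 follows the paper's base model; a depth of 6 is arbitrary here):

x = torch.randn(2, 10, 512)
block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
out, attn = block(x)
print(out.shape, attn.shape)  # torch.Size([2, 10, 512]) torch.Size([2, 8, 10, 10])

# Stacking blocks — the shape is preserved, so layers compose freely
encoder = nn.ModuleList([TransformerBlock(512, 8, 2048) for _ in range(6)])
for blk in encoder:
    x, _ = blk(x)
print(x.shape)  # torch.Size([2, 10, 512])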

12. Test: Multi-Head vs Single Head

x = torch.randn(1, 32, 512)

mha_8 = MultiHeadAttention(512, 8)
mha_1 = MultiHeadAttention(512, 1)

out_8, _ = mha_8(x, x, x)
out_1, _ = mha_1(x, x, x)

print("8 heads output norm:", out_8.norm().item())
print("1 head output norm:", out_1.norm().item())

With trained weights, 8 heads typically yield richer representations; at random initialization the two norms are similar, so treat this as a shape and plumbing check.


13. Summary Cheat Sheet

| Concept | Value |
|---|---|
| Self-Attention | $ Q = K = V = X $ (same sequence) |
| Multi-Head | $ h $ parallel attention heads |
| Head dim | $ d_k = d_{\text{model}} / h $ |
| Split | .view(B, N, h, d_k).transpose(1, 2) |
| Combine | .transpose(1, 2).contiguous().view(B, N, d) |
| Parallelism | GPU runs all heads in one batched matmul |
| Complexity | $ O(N^2 d) $ attention cost (same as a single head) |

14. Practice Exercises

  1. Ablate: Train with 1 vs 8 heads → compare performance on copy task.
  2. Visualize: Plot attention for each head on real sentences.
  3. Efficiency: Measure time for num_heads=1, 8, 16.
  4. Custom: Implement grouped-query attention (GQA); multi-query attention (MQA) is the special case with a single shared K/V head. A starting sketch follows this list.
  5. Debug: Add print(shape) in forward() to trace tensor dims.
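
For exercise 4, a minimal sketch of grouped-query attention to start from — names and defaults here are illustrative, not any particular model's implementation. Queries keep all h heads while K/V use fewer heads shared within groups; num_kv_heads=1 gives MQA:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Sketch of GQA: num_kv_heads < num_heads; num_kv_heads = 1 is MQA."""
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        assert d_model % num_heads == 0 and num_heads % num_kv_heads == 0
        self.h, self.h_kv = num_heads, num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, num_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(num_heads * self.d_k, d_model, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        q = self.W_q(x).view(B, N, self.h, self.d_k).transpose(1, 2)     # (B, h, N, d_k)
        k = self.W_k(x).view(B, N, self.h_kv, self.d_k).transpose(1, 2)  # (B, h_kv, N, d_k)
        v = self.W_v(x).view(B, N, self.h_kv, self.d_k).transpose(1, 2)
        # Each group of h / h_kv query heads shares one K/V head
        k = k.repeat_interleave(self.h // self.h_kv, dim=1)              # (B, h, N, d_k)
        v = v.repeat_interleave(self.h // self.h_kv, dim=1)
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5               # (B, h, N, N)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        out = F.softmax(scores, dim=-1) @ v                              # (B, h, N, d_k)
        out = out.transpose(1, 2).contiguous().view(B, N, -1)            # (B, N, d_model)
        return self.W_o(out)

x = torch.randn(2, 10, 512)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2)
print(gqa(x).shape)  # torch.Size([2, 10, 512])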

15. Key Takeaways

✅ Self-Attention = intra-sequence communication
✅ Multi-Head = parallel feature extractors
✅ Divide & Conquer = split the embedding space into subspaces
✅ Same asymptotic cost, better performance
✅ Enables head specialization

Final Words

You just built the brain of every modern LLM.
Multi-Head Self-Attention = parallel, rich, scalable context.


Full Copy-Paste Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.h, self.d_k = d_model, num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        B = Q.shape[0]
        Q, K, V = self.W_q(Q), self.W_k(K), self.W_v(V)
        Q = Q.view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = K.view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = V.view(B, -1, self.h, self.d_k).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) / self.scale
        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(out), attn

End of Module
You now control parallel attention — the heart of GPT, BERT, and beyond.
Go stack 100 layers.

Last updated: Nov 13, 2025

"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch

"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch

"Attention is All You Need" — Multi-Head & Self-Attention

Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch


Module Objective

Master Multi-Head Self-Attention — the core parallel computation engine of Transformers — with math, code, intuition, and divide-and-conquer parallelism.


1. Self-Attention: One Token Talks to All

Self-Attention = $ Q = K = V $
Every token attends to every other token in the same sequence.

# Input: [batch, seq_len, d_model]
X  Linear  Q, K, V  Attention(Q, K, V)

2. Why Multi-Head? Divide & Conquer

Problem Solution
One attention head = one perspective Multiple heads = multiple subspaces
Risk of missing relations Parallel views → richer representation

"Let the model attend to information from different representation subspaces at different positions."
Vaswani et al., 2017


3. Multi-Head Attention — The Formula

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$

$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

Where:
- $ h $ = number of heads
- $ d_k = d_v = d_{\text{model}} / h $
- $ W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k} $
- $ W^O \in \mathbb{R}^{h d_v \times d} $


4. Step-by-Step: From Single to Multi-Head

Step Operation Shape
1 Project $ X $ → $ Q, K, V $ $ (B, N, d) $
2 Split into $ h $ heads $ (B, h, N, d/h) $
3 Parallel attention $ h $ heads → $ (B, h, N, d/h) $
4 Concat + Linear $ \to (B, N, d) $

5. Multi-Head Attention — From Scratch (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # d_v = d_k

        # Learnable projections
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = (self.d_k) ** 0.5

    def split_heads(self, x):
        """Split last dim into (num_heads, d_k)"""
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (B, h, N, d_k)

    def combine_heads(self, x):
        """Combine heads back to (B, N, d_model)"""
        batch, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.d_model)

    def scaled_dot_product(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, h, N, N)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, V), attn  # (B, h, N, d_k), (B, h, N, N)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 1. Linear projections
        Q = self.W_q(Q)  # (B, N, d)
        K = self.W_k(K)
        V = self.W_v(V)

        # 2. Split into heads
        Q = self.split_heads(Q)  # (B, h, N, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # 3. Apply attention in parallel
        attn_output, attn_weights = self.scaled_dot_product(Q, K, V, mask)
        # attn_output: (B, h, N, d_k)

        # 4. Combine heads
        output = self.combine_heads(attn_output)  # (B, N, d)

        # 5. Final linear
        output = self.W_o(output)

        return output, attn_weights

6. Self-Attention = Multi-Head(Q=X, K=X, V=X)

# Self-Attention
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn = mha(x, x, x)  # Q=K=V=x
print(output.shape)  # (2, 10, 512)

7. Parallelism: Divide & Conquer

Hardware View (GPU)

Input X → [Linear Q] → Split → [Head 1] → Attention
          [Linear K] → Split → [Head 2] → Attention → Concat → W^O
          [Linear V] → Split → [Head 3] → Attention
                             ...
                             [Head 8] → Attention

All 8 heads run in parallel on GPU
Memory: $ O(h \cdot N^2) $ → still $ O(N^2) $, but richer features


8. Visualization: Multi-Head Attention Maps

import matplotlib.pyplot as plt
import seaborn as sns

# Dummy input
x = torch.randn(1, 5, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8)
_, attn_weights = mha(x, x, x)  # (B, h, N, N)

# Plot all heads
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.flatten()

for i in range(8):
    sns.heatmap(
        attn_weights[0, i].detach().cpu(),
        ax=axes[i],
        cmap="viridis",
        cbar=False
    )
    axes[i].set_title(f"Head {i+1}")
    axes[i].set_xticks([])
    axes[i].set_yticks([])

plt.suptitle("Multi-Head Attention Weights (8 Heads)", fontsize=16)
plt.tight_layout()
plt.show()

Each head learns different patterns:
- Head 1: Local
- Head 2: Global
- Head 3: Syntax
- etc.


9. Efficiency: Memory & Compute

Operation Time Memory
Linear Projections $ O(N d^2) $ $ O(N d) $
Split Heads $ O(N d) $ $ O(N d) $
Attention (per head) $ O(N^2 d/h) $ $ O(N^2) $
Total $ O(N^2 d) $ $ O(N^2 + N d) $

Same complexity as single head, but richer output


10. Divide & Conquer Intuition

Single Head (64-dim):
"the cat sat on the mat"
       └────┬────┘
           One view

Multi-Head (8 × 8-dim):
"the cat sat on the mat"
 ├─> "the" ↔ pronouns
 ├─> "cat" ↔ animals
 ├─> "sat" ↔ verbs
 └─> "on" ↔ prepositions

Each head specializesemergent behavior


11. Full Transformer Block with Self-Attention

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + Residual
        attn_out, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed Forward + Residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x, attn_weights

12. Test: Multi-Head vs Single Head

x = torch.randn(1, 32, 512)

mha_8 = MultiHeadAttention(512, 8)
mha_1 = MultiHeadAttention(512, 1)

out_8, _ = mha_8(x, x, x)
out_1, _ = mha_1(x, x, x)

print("8 heads output norm:", out_8.norm().item())
print("1 head output norm:", out_1.norm().item())

8 heads → richer, more stable representations


13. Summary Cheat Sheet

Concept Value
Self-Attention Q = K = V = X
Multi-Head $ h $ parallel attention layers
Head Dim $ d_k = d_{\text{model}} / h $
Split .view(B, N, h, d_k).transpose(1,2)
Combine .transpose(1,2).view(B, N, d)
Parallelism GPU runs all heads at once
Complexity $ O(N^2 d) $ (same as single)

14. Practice Exercises

  1. Ablate: Train with 1 vs 8 heads → compare performance on copy task.
  2. Visualize: Plot attention for each head on real sentences.
  3. Efficiency: Measure time for num_heads=1, 8, 16.
  4. Custom: Implement grouped-query attention (MQA).
  5. Debug: Add print(shape) in forward() to trace tensor dims.

15. Key Takeaways

Check Insight
Check Self-Attention = intra-sequence communication
Check Multi-Head = parallel feature extractors
Check Divide & Conquer = split embedding space
Check Same cost, better performance
Check Enables specialization

Final Words

You just built the brain of every modern LLM.
Multi-Head Self-Attention = parallel, rich, scalable context.


Full Copy-Paste Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.h, self.d_k = d_model, num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        B = Q.shape[0]
        Q, K, V = self.W_q(Q), self.W_k(K), self.W_v(V)
        Q = Q.view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = K.view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = V.view(B, -1, self.h, self.d_k).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) / self.scale
        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(out), attn

End of Module
You now control parallel attention — the heart of GPT, BERT, and beyond.
Go stack 100 layers.

Last updated: Nov 13, 2025