"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).
"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (ha
"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
A Complete One-Module Learning Tutorial with Graphs, Hashing, and Code
Module Objective
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).
1. Core Idea: Why Attention?
"Let every token talk to every other token — weighted by relevance."
Instead of propagating information step by step (RNNs) or through stacked local windows (CNNs), attention computes direct pairwise dependencies between all input tokens, regardless of distance.
2. Scaled Dot-Product Attention — The Formula
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
Where:
- $ Q \in \mathbb{R}^{n \times d_k} $: Queries
- $ K \in \mathbb{R}^{m \times d_k} $: Keys
- $ V \in \mathbb{R}^{m \times d_v} $: Values
- $ d_k $: dimension of keys/queries
- $ n $: number of queries (e.g., output sequence length)
- $ m $: number of keys/values (e.g., input sequence length)
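Before writing any PyTorch, here is a minimal worked example of the formula with tiny hand-picked matrices (the numbers are made up purely for illustration):

```python
import torch
import torch.nn.functional as F

# Two queries, three keys/values, d_k = d_v = 2 (values chosen only for illustration)
Q = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])            # (n=2, d_k=2)
K = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])            # (m=3, d_k=2)
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0],
                  [5.0, 5.0]])            # (m=3, d_v=2)

scores = Q @ K.T / (2 ** 0.5)             # (2, 3): scaled query-key similarities
weights = F.softmax(scores, dim=-1)       # each row is a distribution over the 3 keys
out = weights @ V                         # (2, 2): weighted mixture of value rows
print(weights)
print(out)
```

Each output row is a blend of value rows, weighted by how strongly its query matches each key.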
3. Step-by-Step Breakdown
| Step | Operation | Result |
|---|---|---|
| 1 | $ QK^T $ | $ (n, d_k) \times (d_k, m) \to (n, m) $ score matrix |
| 2 | Scale by $ 1/\sqrt{d_k} $ | Same shape; keeps scores in a range where softmax gradients stay usable |
| 3 | Softmax over last dim | $ (n, m) $ attention weights, each row summing to 1 |
| 4 | Multiply by $ V $ | $ (n, m) \times (m, d_v) \to (n, d_v) $ output |
4. PyTorch Implementation (From Scratch)
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: (batch, n, d_k)
K: (batch, m, d_k)
V: (batch, m, d_v)
"""
d_k = Q.size(-1)
# Step 1: QK^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, n, m)
# Step 2: Scale
scores = scores / (d_k ** 0.5)
# Step 3: Optional Mask (for decoder)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 4: Softmax
attn_weights = F.softmax(scores, dim=-1)
# Step 5: Weighted sum of values
output = torch.matmul(attn_weights, V) # (batch, n, d_v)
return output, attn_weights
5. Test with Dummy Data
batch_size = 1
seq_len = 4
d_k = d_v = 8
# Simulate learned projections
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)
output, attn = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape) # (1, 4, 8)
print("Attention weights shape:", attn.shape) # (1, 4, 4)
6. Visualize Attention Weights
def plot_attention(attn_weights, title="Attention Weights"):
plt.figure(figsize=(6, 5))
sns.heatmap(
attn_weights[0].detach().cpu().numpy(),
cmap="Blues",
annot=True,
fmt=".2f",
xticklabels=[f"Key {i}" for i in range(seq_len)],
yticklabels=[f"Query {i}" for i in range(seq_len)]
)
plt.title(title)
plt.xlabel("Keys")
plt.ylabel("Queries")
plt.show()
plot_attention(attn, "Random Attention (Before Training)")
After training, attention becomes sharp and meaningful (e.g., "it" → "cat").
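To see what "sharp" attention looks like without training anything, you can hand-craft a query that points along one particular key (purely illustrative; this reuses `seq_len`, `d_k`, `V`, and `plot_attention` from above):

```python
# Illustrative only: make Query 2 nearly parallel to Key 0, so its attention row
# puts almost all of its weight on that single key.
Q_sharp = torch.randn(1, seq_len, d_k)
K_sharp = torch.randn(1, seq_len, d_k)
Q_sharp[0, 2] = 5.0 * K_sharp[0, 0]
_, attn_sharp = scaled_dot_product_attention(Q_sharp, K_sharp, V)
plot_attention(attn_sharp, "Hand-Crafted 'Sharp' Attention (Illustrative)")
```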
7. Add Causal Mask (Decoder-Only)
def create_causal_mask(seq_len):
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # True where allowed
mask = create_causal_mask(seq_len)
mask = mask.unsqueeze(0) # (1, seq_len, seq_len)
output_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)
plot_attention(attn_masked, "Causal (Autoregressive) Attention")
Prevents future peeking — essential for language generation.
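A quick sanity check that the mask behaves as claimed: after the softmax, every weight above the diagonal (attention to future positions) is exactly zero, since `exp(-inf) = 0`.

```python
# All attention weights on future keys should be exactly zero
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(attn_masked[0][future].abs().max())  # tensor(0.)
```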
8. Efficiency Problem: $ O(n^2) $ Memory & Time
For the raw attention matrix alone (float32, one head, one batch element), memory grows quadratically:

| Sequence Length | Attention Matrix Memory | Practical Impact |
|---|---|---|
| 512 | ~1 MB | Fast |
| 4,096 | ~67 MB | Slower; multiply by heads × layers × batch size |
| 32,768 | ~4.3 GB | Per head, per layer, per example; quickly prohibitive |
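A quick back-of-the-envelope check of the per-head numbers above (4 bytes per float32 score):

```python
# Per-head, per-example float32 attention-matrix memory
for n in (512, 4096, 32768):
    print(f"n={n:>6}: {n * n * 4 / 1e6:,.1f} MB")
# n=   512: 1.0 MB
# n=  4096: 67.1 MB
# n= 32768: 4,295.0 MB  (~4.3 GB)
```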
9. Optimization: Hashing + Sparse Attention
Idea: Locality-Sensitive Hashing (LSH) for Attention
Only attend to nearby or similar keys → reduce $ O(n^2) \to O(n \log n) $
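To make the hashing idea concrete, here is a tiny sign-based random-hyperplane LSH sketch (a generic LSH scheme for intuition; Reformer's actual hash is a random-rotation argmax variant):

```python
import torch

torch.manual_seed(0)
d, n_planes = 8, 4
planes = torch.randn(d, n_planes)                    # random hyperplanes through the origin

def lsh_bucket(x):
    """Hash a vector to a bucket id from the signs of its projections."""
    bits = (x @ planes > 0).long()                   # which side of each hyperplane
    return int((bits * (2 ** torch.arange(n_planes))).sum())

a = torch.randn(d)
b = a + 0.05 * torch.randn(d)                        # near-duplicate of a
c = torch.randn(d)                                   # unrelated vector
print(lsh_bucket(a), lsh_bucket(b), lsh_bucket(c))   # a and b usually share a bucket; c usually does not
```

Tokens whose keys hash to the same bucket attend to each other; pairs in different buckets are simply skipped, which is where the savings come from.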
LSH Attention (Reformer-style)
import torch.nn as nn

class LSHAttention(nn.Module):
    """Very simplified, Reformer-inspired LSH attention (for intuition, not production)."""
    def __init__(self, d_model, n_hashes=4, bucket_size=64):
        super().__init__()
        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self.d_model = d_model

    def hash_vectors(self, vectors):
        # Random rotation, then take the argmax dimension as the bucket id.
        # (Reformer shares hashes between Q and K and uses [x, -x]; this is a simplification.)
        rotation_matrix = torch.randn(self.d_model, self.d_model, device=vectors.device)
        rotated_vecs = vectors @ rotation_matrix
        buckets = torch.argmax(rotated_vecs, dim=-1)
        return buckets  # (batch, seq_len)

    def forward(self, Q, K, V):
        batch_size, seq_len, d = Q.shape
        all_outputs = []
        for _ in range(self.n_hashes):  # multiple hash rounds reduce the chance of missed matches
            buckets = self.hash_vectors(K)                        # (batch, seq_len)
            _, indices = torch.sort(buckets, dim=-1)              # group similar keys together
            chunks = torch.split(indices, self.bucket_size, dim=1)
            # Approximate attention: each chunk only attends within itself
            chunk_outs = []
            for chunk in chunks:
                idx = chunk.unsqueeze(-1).expand(-1, -1, d)
                Q_chunk = Q.gather(1, idx)
                K_chunk = K.gather(1, idx)
                V_chunk = V.gather(1, idx)
                out, _ = scaled_dot_product_attention(Q_chunk, K_chunk, V_chunk)
                chunk_outs.append(out)
            output = torch.cat(chunk_outs, dim=1)                 # still in sorted-bucket order
            # Scatter rows back to their original token positions
            unsorted = torch.zeros_like(output)
            unsorted.scatter_(1, indices.unsqueeze(-1).expand(-1, -1, d), output)
            all_outputs.append(unsorted.unsqueeze(1))
        final_output = torch.mean(torch.cat(all_outputs, dim=1), dim=1)
        return final_output, None  # per-pair weights are not meaningful here
Memory per hash round: $ O(n \cdot b) $, where $ b $ = bucket size.
LSH attention is the key trick in Reformer; Longformer and BigBird reach similar efficiency with different sparsity patterns (sliding windows and a few global tokens) rather than hashing.
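A quick smoke test of the sketch above (shape check only; the approximation is crude and untrained):

```python
lsh = LSHAttention(d_model=8, n_hashes=2, bucket_size=2)
Q8, K8, V8 = (torch.randn(1, 8, 8) for _ in range(3))
out, _ = lsh(Q8, K8, V8)
print(out.shape)  # torch.Size([1, 8, 8])
```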
10. Full Multi-Head Attention (Transformer Block)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
batch, seq, _ = x.shape
return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
batch, _, seq, d_k = x.shape
return x.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
def forward(self, Q, K, V, mask=None):
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
output = self.combine_heads(attn_output)
return self.W_o(output), attn_weights
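A quick shape check of the multi-head wrapper (the tensor below is random and only for this demo):

```python
mha = MultiHeadAttention(d_model=16, num_heads=4)
tokens = torch.randn(2, 5, 16)                  # (batch, seq, d_model)
y, w = mha(tokens, tokens, tokens)              # self-attention: Q = K = V
print(y.shape, w.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 4, 5, 5])
```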
11. Full Example: Train Tiny Model
# Tiny task: reproduce the (fixed, random) input embeddings through attention
X = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.long)

emb = nn.Embedding(10, 16)       # create the embedding once, outside the training loop
model = MultiHeadAttention(d_model=16, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

x = emb(X).detach()              # fixed inputs/targets; only the attention weights train

for epoch in range(100):
    optimizer.zero_grad()
    output, _ = model(x, x, x)
    loss = F.mse_loss(output, x)
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
12. Summary Cheat Sheet
| Component | Formula | Purpose |
|---|---|---|
| Dot Product | $ QK^T $ | Query-key similarity |
| Scaling | $ \div \sqrt{d_k} $ | Gradient stability |
| Softmax | $ \text{softmax}(\cdot) $ | Normalize scores into attention weights |
| Weighted Sum | $ \text{softmax}(\cdot)\,V $ | Mix value vectors into the output |

| Model | Attention Type | Efficiency |
|---|---|---|
| Transformer | Full | $ O(n^2) $ (baseline) |
| Reformer | LSH | $ O(n \log n) $ |
| Longformer | Sliding Window | $ O(n) $ |
13. Graph: Attention Scaling
import numpy as np
seq_lens = [128, 512, 2048, 8192, 32768]
full_mem = np.array(seq_lens)**2 * 4 / 1e9 # GB (float32)
lsh_mem = np.array(seq_lens) * 64 * 4 / 1e9 # bucket_size=64
plt.figure(figsize=(8, 5))
plt.plot(seq_lens, full_mem, 'r-o', label="Full Attention (O(n²))")
plt.plot(seq_lens, lsh_mem, 'g--s', label="LSH Attention (bucketed, ~O(n·b) memory)")
plt.yscale('log')
plt.xlabel("Sequence Length")
plt.ylabel("Memory (GB)")
plt.title("Attention Memory Scaling")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Practice Exercises
- Implement masked multi-head attention for a decoder.
- Replace softmax with sparsemax or entmax.
- Add relative position encodings.
- Use Performer's FAVOR+ (linear attention).
- Visualize attention on real sentences using Hugging Face.
Key Takeaways
- ✓ Attention = weighted sum of values, guided by query-key similarity
- ✓ Scaling by $ \sqrt{d_k} $ prevents softmax saturation and vanishing gradients
- ✓ Causal mask → autoregressive generation
- ✓ Hashing (LSH) → tractable long sequences
- ✓ Multi-head → multiple perspectives
Final Words
"Attention is All You Need" — not just a paper, but a paradigm shift.
You now have:
- Full mathematical understanding
- Working PyTorch code
- Visualization tools
- Efficiency tricks (hashing)
Next Steps:
Build a mini-Transformer from scratch → train on text → generate poetry!
End of Module
You just built the heart of GPT, BERT, and every modern LLM.
Attention is yours.
"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).
"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (ha
"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch
A Complete One-Module Learning Tutorial with Graphs, Hashing, and Code
Module Objective
Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).
1. Core Idea: Why Attention?
"Let every token talk to every other token — weighted by relevance."
Instead of RNNs or CNNs, Attention computes direct dependencies between input tokens.
2. Scaled Dot-Product Attention — The Formula
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
Where:
- $ Q \in \mathbb{R}^{n \times d_k} $: Queries
- $ K \in \mathbb{R}^{m \times d_k} $: Keys
- $ V \in \mathbb{R}^{m \times d_v} $: Values
- $ d_k $: dimension of keys/queries
- $ n $: number of queries (e.g., output sequence length)
- $ m $: number of keys/values (e.g., input sequence length)
3. Step-by-Step Breakdown
| Step | Operation | Shape |
|---|---|---|
| 1 | $ QK^T $ | $ (n, d_k) \times (d_k, m) \to (n, m) $ |
| 2 | Scale: $ \div \sqrt{d_k} $ | Stabilizes gradients |
| 3 | Softmax over last dim | $ \to $ attention weights |
| 4 | Multiply by $ V $ | $ (n, m) \times (m, d_v) \to (n, d_v) $ |
4. PyTorch Implementation (From Scratch)
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: (batch, n, d_k)
K: (batch, m, d_k)
V: (batch, m, d_v)
"""
d_k = Q.size(-1)
# Step 1: QK^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, n, m)
# Step 2: Scale
scores = scores / (d_k ** 0.5)
# Step 3: Optional Mask (for decoder)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 4: Softmax
attn_weights = F.softmax(scores, dim=-1)
# Step 5: Weighted sum of values
output = torch.matmul(attn_weights, V) # (batch, n, d_v)
return output, attn_weights
5. Test with Dummy Data
batch_size = 1
seq_len = 4
d_k = d_v = 8
# Simulate learned projections
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)
output, attn = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape) # (1, 4, 8)
print("Attention weights shape:", attn.shape) # (1, 4, 4)
6. Visualize Attention Weights
def plot_attention(attn_weights, title="Attention Weights"):
plt.figure(figsize=(6, 5))
sns.heatmap(
attn_weights[0].detach().cpu().numpy(),
cmap="Blues",
annot=True,
fmt=".2f",
xticklabels=[f"Key {i}" for i in range(seq_len)],
yticklabels=[f"Query {i}" for i in range(seq_len)]
)
plt.title(title)
plt.xlabel("Keys")
plt.ylabel("Queries")
plt.show()
plot_attention(attn, "Random Attention (Before Training)")
After training, attention becomes sharp and meaningful (e.g., "it" → "cat").
7. Add Causal Mask (Decoder-Only)
def create_causal_mask(seq_len):
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # True where allowed
mask = create_causal_mask(seq_len)
mask = mask.unsqueeze(0) # (1, seq_len, seq_len)
output_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)
plot_attention(attn_masked, "Causal (Autoregressive) Attention")
Prevents future peeking — essential for language generation.
8. Efficiency Problem: $ O(n^2) $ Memory & Time
| Sequence Length | Memory (GB) | Time |
|---|---|---|
| 512 | ~0.5 GB | Fast |
| 4096 | ~16 GB | Slow |
| 32768 | ~1 TB | Impossible |
9. Optimization: Hashing + Sparse Attention
Idea: Locality-Sensitive Hashing (LSH) for Attention
Only attend to nearby or similar keys → reduce $ O(n^2) \to O(n \log n) $
LSH Attention (Reformer-style)
import torch.nn as nn
class LSHAttention(nn.Module):
def __init__(self, d_model, n_hashes=4, bucket_size=64):
super().__init__()
self.n_hashes = n_hashes
self.bucket_size = bucket_size
self.d_model = d_model
def hash_vectors(self, vectors):
# Random rotation + bucket
rotation_matrix = torch.randn(self.d_model, self.d_model)
rotated_vecs = vectors @ rotation_matrix
buckets = torch.argmax(rotated_vecs, dim=-1)
return buckets
def forward(self, Q, K, V):
batch_size, seq_len, d = Q.shape
# Multi-round LSH
all_outputs = []
all_weights = []
for _ in range(self.n_hashes):
buckets = self.hash_vectors(K) # (batch, seq_len)
sorted buckets, indices = torch.sort(buckets)
# Chunk into buckets
chunks = torch.split(indices, self.bucket_size, dim=1)
# Approximate attention within chunks
chunk_outs = []
for chunk in chunks:
Q_chunk = Q.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))
K_chunk = K.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))
V_chunk = V.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))
out, w = scaled_dot_product_attention(Q_chunk, K_chunk, V_chunk)
chunk_outs.append(out)
output = torch.cat(chunk_outs, dim=1)
all_outputs.append(output.unsqueeze(1))
final_output = torch.mean(torch.cat(all_outputs, dim=1), dim=1)
return final_output, None # weights not meaningful
Memory: $ O(n \cdot b) $ where $ b $ = bucket size
Used in: Reformer, Longformer, BigBird
10. Full Multi-Head Attention (Transformer Block)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
batch, seq, _ = x.shape
return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
batch, _, seq, d_k = x.shape
return x.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
def forward(self, Q, K, V, mask=None):
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
output = self.combine_heads(attn_output)
return self.W_o(output), attn_weights
11. Full Example: Train Tiny Model
# Tiny dataset: learn to copy input
X = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8]], dtype=torch.long)
Y = X.clone()
model = MultiHeadAttention(d_model=16, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
# Embed (simple)
emb = nn.Embedding(10, 16)
x = emb(X)
output, _ = model(x, x, x)
loss = F.mse_loss(output, x)
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
12. Summary Cheat Sheet
| Component | Formula | Purpose |
|---|---|---|
| Dot Product | $ QK^T $ | Similarity |
| Scaling | $ \div \sqrt{d_k} $ | Gradient stability |
| Softmax | $ \text{softmax}(\cdot) $ |
| Model | Attention Type | Efficiency |
|---|---|---|
| Transformer | Full $ O(n^2) $ | Baseline |
| Reformer | LSH | $ O(n \log n) $ |
| Longformer | Sliding Window | $ O(n) $ |
13. Graph: Attention Scaling
import numpy as np
seq_lens = [128, 512, 2048, 8192, 32768]
full_mem = np.array(seq_lens)**2 * 4 / 1e9 # GB (float32)
lsh_mem = np.array(seq_lens) * 64 * 4 / 1e9 # bucket_size=64
plt.figure(figsize=(8, 5))
plt.plot(seq_lens, full_mem, 'r-o', label="Full Attention (O(n²))")
plt.plot(seq_lens, lsh_mem, 'g--s', label="LSH Attention (O(n log n))")
plt.yscale('log')
plt.xlabel("Sequence Length")
plt.ylabel("Memory (GB)")
plt.title("Attention Memory Scaling")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Practice Exercises
- Implement masked multi-head attention for a decoder.
- Replace softmax with sparsemax or entmax.
- Add relative position encodings.
- Use Performer's FAVOR+ (linear attention).
- Visualize attention on real sentences using Hugging Face.
Key Takeaways
| Check | Insight |
|---|---|
| Check | Attention = weighted sum of values, guided by query-key similarity |
| Check | Scaling prevents vanishing gradients |
| Check | Causal mask → autoregressive generation |
| Check | Hashing (LSH) → long sequences |
| Check | Multi-head → multiple perspectives |
Final Words
"Attention is All You Need" — not just a paper, but a paradigm shift.
You now have:
- Full mathematical understanding
- Working PyTorch code
- Visualization tools
- Efficiency tricks (hashing)
Next Steps:
Build a mini-Transformer from scratch → train on text → generate poetry!
End of Module
You just built the heart of GPT, BERT, and every modern LLM.
Attention is yours.