"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).

"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (ha

"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

A Complete One-Module Learning Tutorial with Graphs, Hashing, and Code


Module Objective

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).


1. Core Idea: Why Attention?

"Let every token talk to every other token — weighted by relevance."

Instead of RNNs or CNNs, Attention computes direct dependencies between input tokens.


2. Scaled Dot-Product Attention — The Formula

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

Where:
- $ Q \in \mathbb{R}^{n \times d_k} $: Queries
- $ K \in \mathbb{R}^{m \times d_k} $: Keys
- $ V \in \mathbb{R}^{m \times d_v} $: Values
- $ d_k $: dimension of keys/queries
- $ n $: number of queries (e.g., output sequence length)
- $ m $: number of keys/values (e.g., input sequence length)


3. Step-by-Step Breakdown

| Step | Operation | Shape / Effect |
|------|-----------|----------------|
| 1 | $ QK^T $ | $ (n, d_k) \times (d_k, m) \to (n, m) $ |
| 2 | Scale by $ 1 / \sqrt{d_k} $ | Same shape; stabilizes gradients |
| 3 | Softmax over the last dim | $ (n, m) $ attention weights |
| 4 | Multiply by $ V $ | $ (n, m) \times (m, d_v) \to (n, d_v) $ |
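
To make the shapes concrete, here is a minimal worked example of the four steps with tiny, hand-picked sizes ($ n = m = 3 $, $ d_k = d_v = 2 $); the lowercase names are just for this illustration:

import torch
import torch.nn.functional as F

n, m, d_k, d_v = 3, 3, 2, 2
q = torch.randn(n, d_k)
k = torch.randn(m, d_k)
v = torch.randn(m, d_v)

scores = q @ k.T / d_k ** 0.5        # Steps 1-2: (3, 2) @ (2, 3) -> (3, 3), then scale
weights = F.softmax(scores, dim=-1)  # Step 3: each row sums to 1
out = weights @ v                    # Step 4: (3, 3) @ (3, 2) -> (3, 2)

print(weights.sum(dim=-1))           # tensor([1., 1., 1.]) up to float rounding
print(out.shape)                     # torch.Size([3, 2])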

4. PyTorch Implementation (From Scratch)

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, n, d_k)
    K: (batch, m, d_k)
    V: (batch, m, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: QK^T
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, m)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Step 3: Optional Mask (for decoder)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)  # (batch, n, d_v)

    return output, attn_weights

5. Test with Dummy Data

batch_size = 1
seq_len = 4
d_k = d_v = 8

# Simulate learned projections
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, attn = scaled_dot_product_attention(Q, K, V)

print("Output shape:", output.shape)        # (1, 4, 8)
print("Attention weights shape:", attn.shape)  # (1, 4, 4)

6. Visualize Attention Weights

def plot_attention(attn_weights, title="Attention Weights"):
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        attn_weights[0].detach().cpu().numpy(),
        cmap="Blues",
        annot=True,
        fmt=".2f",
        xticklabels=[f"Key {i}" for i in range(seq_len)],
        yticklabels=[f"Query {i}" for i in range(seq_len)]
    )
    plt.title(title)
    plt.xlabel("Keys")
    plt.ylabel("Queries")
    plt.show()

plot_attention(attn, "Random Attention (Before Training)")

After training, attention becomes sharp and interpretable (e.g., the pronoun "it" attending strongly to its referent "cat").


7. Add Causal Mask (Decoder-Only)

def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # True where allowed

mask = create_causal_mask(seq_len)
mask = mask.unsqueeze(0)  # (1, seq_len, seq_len)

output_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)
plot_attention(attn_masked, "Causal (Autoregressive) Attention")

Prevents future peeking — essential for language generation.
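
A quick sanity check on the masked weights (reusing attn_masked from above): nothing above the diagonal should be non-zero, and each row should still sum to 1 over the allowed positions.

# No query attends to a future key, and each row is still a valid distribution
upper = torch.triu(attn_masked[0], diagonal=1)
assert torch.allclose(upper, torch.zeros_like(upper))
assert torch.allclose(attn_masked.sum(dim=-1), torch.ones(batch_size, seq_len))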


8. Efficiency Problem: $ O(n^2) $ Memory & Time

| Sequence Length | Attention-Matrix Memory | Time |
|-----------------|-------------------------|------|
| 512 | ~0.5 GB | Fast |
| 4,096 | ~16 GB | Slow |
| 32,768 | ~1 TB | Impractical |

(Rough, illustrative figures for batched multi-head attention in float32; actual usage scales with batch size, number of heads, and how many score matrices are alive at once.)
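
For a back-of-the-envelope check of where such figures come from: a single $ n \times n $ float32 score matrix costs $ n^2 \times 4 $ bytes per head per example, and batch size times head count multiplies that further. A small sketch:

# Raw float32 size of one (n x n) attention score matrix, per head and per example
for n in (512, 4096, 32768):
    print(f"n={n:6d}: {n * n * 4 / 1e9:7.3f} GB")
# n=   512:   0.001 GB
# n=  4096:   0.067 GB
# n= 32768:   4.295 GB
# Multiplying by batch_size * num_heads (and by however many such matrices are
# alive across layers) is what pushes totals toward the table's larger numbers.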

9. Optimization: Hashing + Sparse Attention

Idea: Locality-Sensitive Hashing (LSH) for Attention

Only attend to nearby or similar keys → reduce $ O(n^2) \to O(n \log n) $

LSH Attention (Reformer-style)

import torch.nn as nn

class LSHAttention(nn.Module):
    """Simplified, Reformer-style LSH attention (an illustrative sketch)."""

    def __init__(self, d_model, n_hashes=4, bucket_size=64):
        super().__init__()
        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self.d_model = d_model

    def hash_vectors(self, vectors):
        # Random rotation, then bucket by argmax dimension. (Reformer reuses its
        # hash functions; redrawing a fresh rotation per call keeps this short.)
        rotation_matrix = torch.randn(self.d_model, self.d_model,
                                      device=vectors.device, dtype=vectors.dtype)
        rotated_vecs = vectors @ rotation_matrix
        buckets = torch.argmax(rotated_vecs, dim=-1)  # (batch, seq_len)
        return buckets

    def forward(self, Q, K, V):
        batch_size, seq_len, d = Q.shape

        # Multi-round LSH: average the outputs of several independent hash rounds
        all_outputs = []

        for _ in range(self.n_hashes):
            buckets = self.hash_vectors(K)            # (batch, seq_len)
            _, indices = torch.sort(buckets, dim=-1)  # group similar keys together

            # Chunk the sorted positions into fixed-size buckets
            chunks = torch.split(indices, self.bucket_size, dim=1)

            # Exact attention within each chunk only
            chunk_outs = []
            for chunk in chunks:
                idx = chunk.unsqueeze(-1).expand(-1, -1, d)
                Q_chunk = Q.gather(1, idx)
                K_chunk = K.gather(1, idx)
                V_chunk = V.gather(1, idx)

                out, _ = scaled_dot_product_attention(Q_chunk, K_chunk, V_chunk)
                chunk_outs.append(out)

            # Concatenate chunk outputs (still in sorted order) and scatter them
            # back to the original token positions
            sorted_out = torch.cat(chunk_outs, dim=1)
            output = torch.zeros_like(sorted_out)
            output.scatter_(1, indices.unsqueeze(-1).expand(-1, -1, d), sorted_out)

            all_outputs.append(output.unsqueeze(1))

        final_output = torch.mean(torch.cat(all_outputs, dim=1), dim=1)
        return final_output, None  # per-token weights are not meaningful here

Memory: roughly $ O(n \cdot b) $ per hash round, where $ b $ = bucket size
Used in: Reformer (LSH); Longformer and BigBird achieve similar savings with other sparse patterns (sliding window, global + random)
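
A quick, shape-only smoke test of the sketch above, reusing the small Q, K, V from Section 5 (bucket_size=2 just so the 4 tokens split into two buckets):

lsh_attn = LSHAttention(d_model=8, n_hashes=2, bucket_size=2)
out_lsh, _ = lsh_attn(Q, K, V)   # Q, K, V: (1, 4, 8)
print(out_lsh.shape)             # torch.Size([1, 4, 8])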


10. Full Multi-Head Attention (Transformer Block)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch, _, seq, d_k = x.shape
        return x.transpose(1, 2).contiguous().view(batch, seq, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output)
        return self.W_o(output), attn_weights
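
A short usage sketch with hypothetical shapes; note that the causal mask helper from Section 7 broadcasts over both the batch and the head dimension:

mha = MultiHeadAttention(d_model=16, num_heads=4)
x = torch.randn(2, 10, 16)                    # (batch, seq, d_model)
causal = create_causal_mask(10).unsqueeze(0)  # (1, 10, 10), broadcasts over batch and heads
y, w = mha(x, x, x, mask=causal)

print(y.shape)  # torch.Size([2, 10, 16])
print(w.shape)  # torch.Size([2, 4, 10, 10]), one weight matrix per head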

11. Full Example: Train Tiny Model

# Tiny dataset: learn to copy (reconstruct) the input embeddings
X = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.long)

emb = nn.Embedding(10, 16)   # fixed embedding table, created once (not inside the loop)
model = MultiHeadAttention(d_model=16, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()

    x = emb(X).detach()           # (2, 4, 16); the embeddings themselves are not trained
    output, _ = model(x, x, x)
    loss = F.mse_loss(output, x)  # copy task: reproduce the input

    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

12. Summary Cheat Sheet

| Component | Formula | Purpose |
|-----------|---------|---------|
| Dot product | $ QK^T $ | Query-key similarity |
| Scaling | $ \div \sqrt{d_k} $ | Gradient stability |
| Softmax | $ \text{softmax}(\cdot) $ | Turn scores into attention weights |

| Model | Attention Type | Complexity |
|-------|----------------|------------|
| Transformer | Full | $ O(n^2) $ (baseline) |
| Reformer | LSH | $ O(n \log n) $ |
| Longformer | Sliding window | $ O(n) $ |

13. Graph: Attention Scaling

import numpy as np

seq_lens = [128, 512, 2048, 8192, 32768]
full_mem = np.array(seq_lens, dtype=np.float64)**2 * 4 / 1e9   # float32: n^2 * 4 bytes
lsh_mem = np.array(seq_lens, dtype=np.float64) * 64 * 4 / 1e9  # bucketed: n * bucket_size * 4 bytes

plt.figure(figsize=(8, 5))
plt.plot(seq_lens, full_mem, 'r-o', label="Full Attention (O(n²))")
plt.plot(seq_lens, lsh_mem, 'g--s', label="LSH Attention (~O(n·b), bucket_size=64)")
plt.yscale('log')
plt.xlabel("Sequence Length")
plt.ylabel("Memory (GB)")
plt.title("Attention Memory Scaling")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Practice Exercises

  1. Implement masked multi-head attention for a decoder.
  2. Replace softmax with sparsemax or entmax.
  3. Add relative position encodings.
  4. Use Performer's FAVOR+ (linear attention).
  5. Visualize attention on real sentences using Hugging Face.

Key Takeaways

- Attention = a weighted sum of values, guided by query-key similarity
- Scaling by $ \sqrt{d_k} $ keeps the softmax out of its vanishing-gradient regime
- Causal mask → autoregressive generation
- Hashing (LSH) → tractable attention over long sequences
- Multi-head → several attention patterns learned in parallel

Final Words

"Attention is All You Need" — not just a paper, but a paradigm shift.

You now have:
- Full mathematical understanding
- Working PyTorch code
- Visualization tools
- Efficiency tricks (hashing)


Next Steps:
Build a mini-Transformer from scratch → train on text → generate poetry!


End of Module
You just built the heart of GPT, BERT, and every modern LLM.
Attention is yours.

Last updated: Nov 13, 2025

"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).

"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (ha

"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

A Complete One-Module Learning Tutorial with Graphs, Hashing, and Code


Module Objective

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).


1. Core Idea: Why Attention?

"Let every token talk to every other token — weighted by relevance."

Instead of RNNs or CNNs, Attention computes direct dependencies between input tokens.


2. Scaled Dot-Product Attention — The Formula

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

Where:
- $ Q \in \mathbb{R}^{n \times d_k} $: Queries
- $ K \in \mathbb{R}^{m \times d_k} $: Keys
- $ V \in \mathbb{R}^{m \times d_v} $: Values
- $ d_k $: dimension of keys/queries
- $ n $: number of queries (e.g., output sequence length)
- $ m $: number of keys/values (e.g., input sequence length)


3. Step-by-Step Breakdown

Step Operation Shape
1 $ QK^T $ $ (n, d_k) \times (d_k, m) \to (n, m) $
2 Scale: $ \div \sqrt{d_k} $ Stabilizes gradients
3 Softmax over last dim $ \to $ attention weights
4 Multiply by $ V $ $ (n, m) \times (m, d_v) \to (n, d_v) $

4. PyTorch Implementation (From Scratch)

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, n, d_k)
    K: (batch, m, d_k)
    V: (batch, m, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: QK^T
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, m)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Step 3: Optional Mask (for decoder)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)  # (batch, n, d_v)

    return output, attn_weights

5. Test with Dummy Data

batch_size = 1
seq_len = 4
d_k = d_v = 8

# Simulate learned projections
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, attn = scaled_dot_product_attention(Q, K, V)

print("Output shape:", output.shape)        # (1, 4, 8)
print("Attention weights shape:", attn.shape)  # (1, 4, 4)

6. Visualize Attention Weights

def plot_attention(attn_weights, title="Attention Weights"):
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        attn_weights[0].detach().cpu().numpy(),
        cmap="Blues",
        annot=True,
        fmt=".2f",
        xticklabels=[f"Key {i}" for i in range(seq_len)],
        yticklabels=[f"Query {i}" for i in range(seq_len)]
    )
    plt.title(title)
    plt.xlabel("Keys")
    plt.ylabel("Queries")
    plt.show()

plot_attention(attn, "Random Attention (Before Training)")

After training, attention becomes sharp and meaningful (e.g., "it" → "cat").


7. Add Causal Mask (Decoder-Only)

def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # True where allowed

mask = create_causal_mask(seq_len)
mask = mask.unsqueeze(0)  # (1, seq_len, seq_len)

output_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)
plot_attention(attn_masked, "Causal (Autoregressive) Attention")

Prevents future peeking — essential for language generation.


8. Efficiency Problem: $ O(n^2) $ Memory & Time

Sequence Length Memory (GB) Time
512 ~0.5 GB Fast
4096 ~16 GB Slow
32768 ~1 TB Impossible

9. Optimization: Hashing + Sparse Attention

Idea: Locality-Sensitive Hashing (LSH) for Attention

Only attend to nearby or similar keys → reduce $ O(n^2) \to O(n \log n) $

LSH Attention (Reformer-style)

import torch.nn as nn

class LSHAttention(nn.Module):
    def __init__(self, d_model, n_hashes=4, bucket_size=64):
        super().__init__()
        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self.d_model = d_model

    def hash_vectors(self, vectors):
        # Random rotation + bucket
        rotation_matrix = torch.randn(self.d_model, self.d_model)
        rotated_vecs = vectors @ rotation_matrix
        buckets = torch.argmax(rotated_vecs, dim=-1)
        return buckets

    def forward(self, Q, K, V):
        batch_size, seq_len, d = Q.shape

        # Multi-round LSH
        all_outputs = []
        all_weights = []

        for _ in range(self.n_hashes):
            buckets = self.hash_vectors(K)  # (batch, seq_len)
            sorted buckets, indices = torch.sort(buckets)

            # Chunk into buckets
            chunks = torch.split(indices, self.bucket_size, dim=1)

            # Approximate attention within chunks
            chunk_outs = []
            for chunk in chunks:
                Q_chunk = Q.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))
                K_chunk = K.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))
                V_chunk = V.gather(1, chunk.unsqueeze(-1).expand(-1, -1, d))

                out, w = scaled_dot_product_attention(Q_chunk, K_chunk, V_chunk)
                chunk_outs.append(out)

            output = torch.cat(chunk_outs, dim=1)
            all_outputs.append(output.unsqueeze(1))

        final_output = torch.mean(torch.cat(all_outputs, dim=1), dim=1)
        return final_output, None  # weights not meaningful

Memory: $ O(n \cdot b) $ where $ b $ = bucket size
Used in: Reformer, Longformer, BigBird


10. Full Multi-Head Attention (Transformer Block)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch, _, seq, d_k = x.shape
        return x.transpose(1, 2).contiguous().view(batch, seq, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output)
        return self.W_o(output), attn_weights

11. Full Example: Train Tiny Model

# Tiny dataset: learn to copy input
X = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.long)
Y = X.clone()

model = MultiHeadAttention(d_model=16, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()

    # Embed (simple)
    emb = nn.Embedding(10, 16)
    x = emb(X)

    output, _ = model(x, x, x)
    loss = F.mse_loss(output, x)

    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

12. Summary Cheat Sheet

Component Formula Purpose
Dot Product $ QK^T $ Similarity
Scaling $ \div \sqrt{d_k} $ Gradient stability
Softmax $ \text{softmax}(\cdot) $
Model Attention Type Efficiency
Transformer Full $ O(n^2) $ Baseline
Reformer LSH $ O(n \log n) $
Longformer Sliding Window $ O(n) $

13. Graph: Attention Scaling

import numpy as np

seq_lens = [128, 512, 2048, 8192, 32768]
full_mem = np.array(seq_lens)**2 * 4 / 1e9  # GB (float32)
lsh_mem = np.array(seq_lens) * 64 * 4 / 1e9  # bucket_size=64

plt.figure(figsize=(8, 5))
plt.plot(seq_lens, full_mem, 'r-o', label="Full Attention (O(n²))")
plt.plot(seq_lens, lsh_mem, 'g--s', label="LSH Attention (O(n log n))")
plt.yscale('log')
plt.xlabel("Sequence Length")
plt.ylabel("Memory (GB)")
plt.title("Attention Memory Scaling")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Practice Exercises

  1. Implement masked multi-head attention for a decoder.
  2. Replace softmax with sparsemax or entmax.
  3. Add relative position encodings.
  4. Use Performer's FAVOR+ (linear attention).
  5. Visualize attention on real sentences using Hugging Face.

Key Takeaways

Check Insight
Check Attention = weighted sum of values, guided by query-key similarity
Check Scaling prevents vanishing gradients
Check Causal mask → autoregressive generation
Check Hashing (LSH) → long sequences
Check Multi-head → multiple perspectives

Final Words

"Attention is All You Need" — not just a paper, but a paradigm shift.

You now have:
- Full mathematical understanding
- Working PyTorch code
- Visualization tools
- Efficiency tricks (hashing)


Next Steps:
Build a mini-Transformer from scratch → train on text → generate poetry!


End of Module
You just built the heart of GPT, BERT, and every modern LLM.
Attention is yours.

Last updated: Nov 13, 2025