Beam Search & Sampling
Complete Module: Priority Queues, Heaps, Top-k, Nucleus Sampling
Module Objective
Master advanced text generation: Beam Search, Top-k, and Nucleus (Top-p) sampling built on priority queues (heaps), with full PyTorch implementations that produce noticeably more coherent output than greedy decoding.
1. Why Not Greedy Decoding?
# Greedy: always pick argmax
next_token = logits[:, -1].argmax()
Problems:
- Gets stuck in local optima
- Misses high-probability sequences (see the toy example below)
- Produces repetitive output: "I I I I I..."
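A toy illustration of the second point, with made-up probabilities: greedy commits to the locally best first token even when a slightly worse first choice leads to a much better sequence overall.
# Hypothetical two-step probabilities (illustrative numbers, not from a real model)
greedy_path = 0.6 * 0.1   # greedy picks the 0.6 token; its best continuation has prob 0.1
better_path = 0.4 * 0.9   # the 0.4 token leads to a 0.9 continuation
print(f"{greedy_path:.2f} vs {better_path:.2f}")  # 0.06 vs 0.36: greedy missed the better sequence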
2. Beam Search: Keep Top-k Paths
"Explore multiple futures, keep the best"
Step 1:
"The" → ["cat", "dog", "man"]
Step 2:
"The cat" → ["sat", "is", "jumped"]
"The dog" → ["barked", "ran", "is"]
...
→ Keep top-3 sequences by log-prob
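A minimal sketch of one beam step in plain Python, using made-up next-token probabilities instead of a real model: every surviving prefix is expanded, candidates are scored by summed log-probs, and only the top 3 are kept.
import math
beam_width = 3
# Hypothetical next-token distributions per prefix (illustration only)
next_probs = {
    "The cat": {"sat": 0.5, "is": 0.3, "jumped": 0.2},
    "The dog": {"barked": 0.6, "ran": 0.3, "is": 0.1},
    "The man": {"said": 0.7, "walked": 0.2, "is": 0.1},
}
# Current beam: (cumulative log-prob, prefix)
beam = [(math.log(0.4), "The cat"), (math.log(0.35), "The dog"), (math.log(0.25), "The man")]
candidates = []
for log_p, prefix in beam:
    for tok, p in next_probs[prefix].items():
        candidates.append((log_p + math.log(p), prefix + " " + tok))
beam = sorted(candidates, reverse=True)[:beam_width]  # keep the top-3 by cumulative log-prob
for score, seq in beam:
    print(f"{score:.2f}  {seq}")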
3. Priority Queue (Heap) = Core of Beam Search
import heapq
# (log_prob, sequence)
beam = [(-0.1, [1]), (-0.2, [2]), (-0.3, [3])]
heapq.heapify(beam)
Heap operations:
- heappush: O(log k)
- heappop: O(log k)
- Beam width k = 5 → fast (a bounded-heap sketch follows this list)
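Because heapq is a min-heap, a fixed-size beam can be maintained with heappushpop, which inserts a candidate and evicts the current worst in one O(log k) step. A minimal sketch with made-up scores:
import heapq
beam_width = 3
beam = []  # min-heap of (log_prob, sequence); the worst candidate sits at beam[0]
candidates = [(-0.1, [1, 2]), (-0.5, [1, 3]), (-0.2, [1, 4]), (-0.9, [1, 5]), (-0.3, [1, 6])]
for cand in candidates:
    if len(beam) < beam_width:
        heapq.heappush(beam, cand)      # O(log k)
    else:
        heapq.heappushpop(beam, cand)   # push, then pop the worst: O(log k)
print(sorted(beam, reverse=True))  # [(-0.1, [1, 2]), (-0.2, [1, 4]), (-0.3, [1, 6])]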
4. Beam Search Implementation
@torch.no_grad()
def beam_search(model, idx, beam_width=5, max_len=50, eos_token=0):
    model.eval()
    # Initial beam: (log_prob, sequence, cache); idx is (1, T), so flatten it to a token list
    beam = [(0.0, idx[0].tolist(), None)]
    for _ in range(max_len):
        all_candidates = []
        for log_prob, seq, cache in beam:
            # Finished sequences are carried over unchanged instead of being expanded further
            if seq[-1] == eos_token:
                all_candidates.append((log_prob, seq, cache))
                continue
            # With a KV cache only the last token is fed; otherwise the full prefix
            if cache is not None:
                input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long)
            else:
                input_tensor = torch.tensor([seq], dtype=torch.long)
            logits, _, new_cache = model(input_tensor, past_kv=cache)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
            # Expand this path with its top beam_width continuations
            top_k = torch.topk(log_probs, beam_width)
            for i in range(beam_width):
                next_token = top_k.indices[i].item()
                new_log_prob = log_prob + top_k.values[i].item()
                new_seq = seq + [next_token]
                # Children share the parent's new cache (they have the same prefix)
                all_candidates.append((new_log_prob, new_seq, new_cache))
        # Keep top beam_width candidates by cumulative log-prob
        beam = heapq.nlargest(beam_width, all_candidates, key=lambda x: x[0])
        # Early stop once every beam has emitted EOS
        if all(seq[-1] == eos_token for _, seq, _ in beam):
            break
    # Return the highest-scoring sequence
    best_seq = max(beam, key=lambda x: x[0])[1]
    return torch.tensor(best_seq)
5. Top-k Sampling
def top_k_sampling(logits, k=50, temperature=1.0):
    logits = logits / temperature
    top_k = torch.topk(logits, k)                       # values/indices of the k most likely tokens
    probs = F.softmax(top_k.values, dim=-1)
    next_token = torch.multinomial(probs, 1)            # index into the top-k slice
    return torch.gather(top_k.indices, -1, next_token)  # map back to the real token id
6. Nucleus (Top-p) Sampling
"Sample from smallest set whose cumulative prob > p"
def nucleus_sampling(logits, p=0.9, temperature=1.0):
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    # Mask tokens outside the nucleus, shifted right so the token that crosses p is kept
    mask = cum_probs > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False
    filtered_probs = sorted_probs.clone()
    filtered_probs[mask] = 0
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(filtered_probs, 1)        # index into the sorted order
    return torch.gather(sorted_indices, -1, next_token)      # map back to the real token id
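A quick sanity check of the cutoff using the nucleus_sampling function above on a hand-picked 4-token distribution (the probabilities are illustrative): the cumulative sums are 0.5, 0.8, 0.95, 1.0, so with p = 0.9 the nucleus keeps the first three tokens and renormalizes over them.
toy_probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
toy_logits = torch.log(toy_probs)  # softmax of log-probs recovers the same distribution
torch.manual_seed(0)
counts = torch.zeros(4)
for _ in range(10_000):
    tok = nucleus_sampling(toy_logits, p=0.9)
    counts[tok.item()] += 1
print(counts / counts.sum())  # roughly [0.53, 0.32, 0.16, 0.00]; token 3 is never sampled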
7. Full Generation with All Methods
@torch.no_grad()
def generate(model, prompt, method="beam", **kwargs):
    idx = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0)
    if method == "beam":
        # Beam search manages its own loop and cache
        return beam_search(model, idx,
                           beam_width=kwargs.get("beam_width", 5),
                           max_len=kwargs.get("max_len", 50))
    cache = None
    for _ in range(kwargs.get("max_len", 100)):
        logits, _, cache = model(idx if cache is None else idx[:, -1:], past_kv=cache)
        logits = logits[:, -1, :]
        if method == "greedy":
            next_token = logits.argmax(-1, keepdim=True)
        elif method == "topk":
            next_token = top_k_sampling(logits, k=kwargs.get("k", 50),
                                        temperature=kwargs.get("temperature", 1.0))
        elif method == "nucleus":
            next_token = nucleus_sampling(logits, p=kwargs.get("p", 0.9),
                                          temperature=kwargs.get("temperature", 1.0))
        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == 0:  # EOS
            break
    return idx
8. Comparison: All Methods on TinyShakespeare
prompt = "ROMEO:"
print("Greedy:")
print(decode(generate(model, prompt, method="greedy", max_len=100)[0].tolist()))
print("\nTop-k (k=40):")
print(decode(generate(model, prompt, method="topk", k=40, temperature=0.8, max_len=100)[0].tolist()))
print("\nNucleus (p=0.9):")
print(decode(generate(model, prompt, method="nucleus", p=0.9, temperature=1.0, max_len=100)[0].tolist()))
print("\nBeam Search (width=5):")
beam_out = beam_search(model, torch.tensor(encode(prompt)).unsqueeze(0), beam_width=5)
print(decode(beam_out.tolist()))
Results:
- Greedy: repetitive
- Top-k: diverse, sometimes incoherent
- Nucleus: best balance
- Beam: most fluent, but deterministic
9. Priority Queue (Heap) in Action
import heapq
# Simulate a beam. heapq is a min-heap: heappop removes the *smallest* score,
# i.e. the worst candidate, which is handy for evicting paths, not for fetching the best.
beam = []
heapq.heappush(beam, (-0.1, [1, 2]))  # (log_prob, seq)
heapq.heappush(beam, (-0.3, [1, 3]))
heapq.heappush(beam, (-0.2, [1, 4]))
print(heapq.heappop(beam))  # (-0.3, [1, 3]) → worst path is popped first
print(max(beam))            # (-0.1, [1, 2]) → best path
10. Beam Search with Length Normalization
# Prevent the bias toward short sequences (summed log-probs only get more negative)
score = log_prob / (len(seq) ** 0.6)
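To plug this into the beam_search above, only the selection key changes: rank candidates by normalized score instead of raw log-prob. A minimal sketch with made-up candidates (alpha = 0.6 is a common choice, but it is a tunable assumption):
import heapq

def length_normalized(candidate, alpha=0.6):
    log_prob, seq, _ = candidate          # same (log_prob, seq, cache) tuples as in beam_search
    return log_prob / (len(seq) ** alpha)

# A short candidate with a better raw score vs. a longer one
all_candidates = [(-2.0, [1, 2, 3], None), (-2.5, [1, 2, 3, 4, 5, 6], None)]
print(heapq.nlargest(1, all_candidates, key=lambda c: c[0])[0][1])     # raw score → short sequence
print(heapq.nlargest(1, all_candidates, key=length_normalized)[0][1])  # normalized → long sequence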
11. Summary Table
| Method | Diversity | Coherence | Speed | Use Case |
|---|---|---|---|---|
| Greedy | Low | Medium | Fastest | Baseline |
| Top-k | Medium | Medium | Fast | General |
| Nucleus | High | High | Fast | Best for creativity |
| Beam | Low | Highest | Slow | Best for accuracy |
12. Practice Exercises
- Add length penalty to beam search
- Implement diverse beam search
- Combine top-k + nucleus (a possible sketch follows this list)
- Measure perplexity of outputs
- Visualize probability mass
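For the "Combine top-k + nucleus" exercise, one possible starting point (an assumption of mine, not the only approach): first truncate to the k most likely tokens, then apply the nucleus cutoff to whatever survives.
import torch
import torch.nn.functional as F

def top_k_then_nucleus(logits, k=50, p=0.9, temperature=1.0):
    """Sketch: top-k filter, then top-p filter, then sample. Assumes logits of shape (1, vocab)."""
    logits = logits / temperature
    # Top-k: everything below the k-th largest logit is masked out
    kth_value = torch.topk(logits, k).values[..., -1, None]
    logits = logits.masked_fill(logits < kth_value, float("-inf"))
    # Nucleus on the surviving tokens
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    mask = torch.cumsum(sorted_probs, dim=-1) > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1)
    return torch.gather(sorted_indices, -1, choice)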
13. Key Takeaways
| ✓ | Insight |
|---|---|
| ✓ | Beam Search = pruned breadth-first search over a heap |
| ✓ | Top-k = truncate the tail of the distribution |
| ✓ | Nucleus = dynamic truncation by cumulative probability |
| ✓ | Nucleus usually beats Top-k in practice |
| ✓ | Temperature and top-p sampling are standard in production LLMs (GPT, Claude, Gemini) |
Full Copy-Paste: All Decoding Methods
import torch
import torch.nn.functional as F
import heapq
# === Top-k === (assumes 1-D logits for a single position)
def top_k(logits, k=50, t=1.0):
    logits = logits / t
    v, ix = torch.topk(logits, k)
    probs = F.softmax(v, dim=-1)
    choice = torch.multinomial(probs, 1)  # index into the top-k slice
    return ix[choice]                     # map back to the real token id
# === Nucleus === (assumes 1-D logits)
def nucleus(logits, p=0.9, t=1.0):
    logits = logits / t
    probs = F.softmax(logits, dim=-1)
    s_idx = torch.argsort(probs, descending=True)
    s_probs = probs[s_idx]
    cum = torch.cumsum(s_probs, dim=-1)
    mask = cum > p
    mask[1:] = mask[:-1].clone()  # shift right so the token that crosses p is kept
    mask[0] = False
    s_probs[mask] = 0
    s_probs = s_probs / s_probs.sum()
    idx = torch.multinomial(s_probs, 1)
    return s_idx[idx]
# === Beam Search === (expects idx of shape (1, T) and a model returning (logits, loss, cache))
@torch.no_grad()
def beam(model, idx, k=5, max_len=50):
    beams = [(0.0, idx[0].tolist(), None)]  # (cumulative log-prob, token list, KV cache)
    for _ in range(max_len):
        cands = []
        for lp, seq, cache in beams:
            # With a cache only the last token is fed; otherwise the full prefix
            x = torch.tensor([[seq[-1]]]) if cache is not None else torch.tensor([seq])
            logits, _, nc = model(x, past_kv=cache)
            logp = F.log_softmax(logits[0, -1], dim=-1)
            top = logp.topk(k)
            for val, tok in zip(top.values.tolist(), top.indices.tolist()):
                cands.append((lp + val, seq + [tok], nc))
        beams = heapq.nlargest(k, cands, key=lambda c: c[0])
    return torch.tensor(max(beams, key=lambda c: c[0])[1])
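A quick smoke test for the two sampling helpers above, using random 1-D logits (the beam helper needs a real model with the (logits, loss, cache) interface assumed throughout this module, so it is not exercised here):
torch.manual_seed(0)
fake_logits = torch.randn(100)           # pretend vocabulary of 100 token ids
print(top_k(fake_logits, k=10, t=0.8))   # one sampled token id from the 10 most likely
print(nucleus(fake_logits, p=0.9))       # one sampled token id from the nucleus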
Final Words
You now control how LLMs think.
- Greedy → robot
- Beam → perfectionist
- Nucleus → creative genius
End of Module
You can now generate text the way production LLMs do: coherent, diverse, and fast.
Next: Build a chatbot API.