Beam Search & Sampling

Complete Module: Priority Queues, Heaps, Top-k, Nucleus Sampling


Module Objective

Master advanced text generation: Beam Search, Top-k, and Nucleus (Top-p) sampling, built on priority queues (heaps), with a full PyTorch implementation and markedly more coherent output than greedy decoding.


1. Why Not Greedy Decoding?

# Greedy: always pick argmax
next_token = logits[:, -1].argmax()

Problem:
- Gets stuck in local optima
- Misses high-probability sequences
- Repetitive: "I I I I I..."
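
To make the failure mode concrete, here is a minimal greedy decoding loop (a sketch; it assumes the same model interface used later in this module, i.e. model(x, past_kv=...) returning logits, loss, and a KV cache):

import torch

@torch.no_grad()
def greedy_decode(model, idx, max_len=50, eos_token=0):
    # idx: (1, T) tensor of prompt token ids
    model.eval()
    for _ in range(max_len):
        logits, _, _ = model(idx, past_kv=None)
        # Always commit to the single most likely token: once the model starts
        # repeating itself, nothing ever breaks the cycle
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == eos_token:
            break
    return idx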


2. Beam Search: Keep Top-k Paths

"Explore multiple futures, keep the best"

Step 1:
  "The" → ["cat", "dog", "man"]

Step 2:
  "The cat" → ["sat", "is", "jumped"]
  "The dog" → ["barked", "ran", "is"]
  ...
→ Keep top-3 sequences by log-prob


3. Priority Queues (Heaps) for the Beam

import heapq

# (log_prob, sequence)
beam = [(-0.1, [1]), (-0.2, [2]), (-0.3, [3])]
heapq.heapify(beam)

Heap operations:
- heapify: O(k)
- heappush / heappop: O(log k)
- Beam width k = 5 → negligible per-step overhead
Note: Python's heapq is a min-heap, so the root is always the smallest (worst) log-prob, which is exactly the element you want to evict once the beam is full.
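
A small sketch of keeping only the k best candidates with a bounded min-heap (pure heapq, no model; the scores below are made up for illustration). heappushpop evicts the current worst in O(log k):

import heapq

def keep_top_k(candidates, k):
    # candidates: iterable of (log_prob, sequence) pairs; returns the k highest-scoring ones
    heap = []
    for cand in candidates:
        if len(heap) < k:
            heapq.heappush(heap, cand)
        else:
            heapq.heappushpop(heap, cand)   # push, then drop the current worst
    return sorted(heap, reverse=True)       # best first

cands = [(-0.1, [1, 2]), (-0.9, [1, 5]), (-0.3, [1, 3]), (-0.2, [1, 4])]
print(keep_top_k(cands, k=3))  # [(-0.1, [1, 2]), (-0.2, [1, 4]), (-0.3, [1, 3])]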


4. Beam Search Implementation

@torch.no_grad()
def beam_search(model, idx, beam_width=5, max_len=50, eos_token=0):
    model.eval()

    # Initial beam: (log_prob, sequence, cache)
    # idx has shape (1, T); flatten it into a plain Python list of token ids
    beam = [(0.0, idx[0].tolist(), None)]

    for _ in range(max_len):
        all_candidates = []

        for log_prob, seq, cache in beam:
            # Finished hypotheses are carried forward unchanged
            if seq[-1] == eos_token:
                all_candidates.append((log_prob, seq, cache))
                continue

            # With a cache, feed only the last token; otherwise feed the whole sequence
            if cache is not None:
                input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long)
            else:
                input_tensor = torch.tensor([seq], dtype=torch.long)

            # Assumes the model returns (logits, loss, new_cache) and accepts past_kv
            logits, _, new_cache = model(input_tensor, past_kv=cache)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            # Expand this path with its top beam_width continuations
            top_k = torch.topk(log_probs, beam_width)
            for i in range(beam_width):
                next_token = top_k.indices[i].item()
                new_log_prob = log_prob + top_k.values[i].item()
                all_candidates.append((new_log_prob, seq + [next_token], new_cache))

        # Keep the top beam_width candidates overall (heap-based selection)
        beam = heapq.nlargest(beam_width, all_candidates, key=lambda x: x[0])

        # Early stop once every surviving beam has emitted EOS
        if all(seq[-1] == eos_token for _, seq, _ in beam):
            break

    # Return the highest-scoring sequence
    best_seq = max(beam, key=lambda x: x[0])[1]
    return torch.tensor(best_seq, dtype=torch.long)

5. Top-k Sampling

def top_k_sampling(logits, k=50, temperature=1.0):
    # logits: (1, vocab_size) for the current position
    logits = logits / temperature
    top_k = torch.topk(logits, k)
    probs = F.softmax(top_k.values, dim=-1)
    next_token = torch.multinomial(probs, 1)             # index into the top-k set
    return torch.gather(top_k.indices, -1, next_token)   # map back to a vocabulary id, shape (1, 1)
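
A quick sanity check on made-up logits (a sketch; the vocabulary size of 65 is just a placeholder for a small character-level model):

import torch

torch.manual_seed(0)
fake_logits = torch.randn(1, 65)                  # (batch=1, vocab=65), random values
tok = top_k_sampling(fake_logits, k=10, temperature=0.8)
print(tok.shape, tok.item())                      # torch.Size([1, 1]) and a token id in [0, 65)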

6. Nucleus (Top-p) Sampling

"Sample from smallest set whose cumulative prob > p"

def nucleus_sampling(logits, p=0.9, temperature=1.0):
    # logits: (1, vocab_size) for the current position
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    # Mask out tokens once the cumulative probability exceeds p,
    # shifting by one so the token that crosses the threshold is still kept
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    mask = cum_probs > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False

    filtered_probs = sorted_probs.clone()
    filtered_probs[mask] = 0
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

    next_token = torch.multinomial(filtered_probs, 1)        # index into the sorted order
    return torch.gather(sorted_indices, -1, next_token)      # map back to a vocabulary id, shape (1, 1)
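
Unlike top-k, the nucleus adapts to the shape of the distribution: a peaked distribution yields a tiny nucleus, a flat one keeps almost everything. A small sketch with made-up logits (nucleus_size is a helper introduced here for illustration):

import torch
import torch.nn.functional as F

def nucleus_size(logits, p=0.9):
    # how many tokens survive nucleus filtering for these logits
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    return int((cum < p).sum().item()) + 1   # smallest prefix whose mass reaches p

peaked = torch.zeros(1, 65); peaked[0, 0] = 10.0   # one dominant token
flat = torch.zeros(1, 65)                          # uniform distribution
print(nucleus_size(peaked), nucleus_size(flat))    # 1 vs 59: tiny nucleus vs. almost the whole vocab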

7. Full Generation with All Methods

@torch.no_grad()
def generate(model, prompt, method="beam", **kwargs):
    idx = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0)

    if method == "beam":
        # Beam search manages its own loop and cache
        return beam_search(model, idx,
                           beam_width=kwargs.get("beam_width", 5),
                           max_len=kwargs.get("max_len", 50),
                           eos_token=kwargs.get("eos_token", 0))

    cache = None
    temperature = kwargs.get("temperature", 1.0)

    for _ in range(kwargs.get("max_len", 100)):
        logits, _, cache = model(idx if cache is None else idx[:, -1:], past_kv=cache)
        logits = logits[:, -1, :]

        if method == "greedy":
            next_token = logits.argmax(-1, keepdim=True)
        elif method == "topk":
            next_token = top_k_sampling(logits, k=kwargs.get("k", 50), temperature=temperature)
        elif method == "nucleus":
            next_token = nucleus_sampling(logits, p=kwargs.get("p", 0.9), temperature=temperature)
        else:
            raise ValueError(f"unknown method: {method}")

        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == kwargs.get("eos_token", 0):
            break

    return idx

8. Comparison: All Methods on TinyShakespeare

prompt = "ROMEO:"

print("Greedy:")
print(decode(generate(model, prompt, method="greedy", max_len=100)[0].tolist()))

print("\nTop-k (k=40):")
print(decode(generate(model, prompt, method="topk", k=40, temperature=0.8, max_len=100)[0].tolist()))

print("\nNucleus (p=0.9):")
print(decode(generate(model, prompt, method="nucleus", p=0.9, temperature=1.0, max_len=100)[0].tolist()))

print("\nBeam Search (width=5):")
beam_out = beam_search(model, torch.tensor(encode(prompt)).unsqueeze(0), beam_width=5)
print(decode(beam_out.tolist()))

Results:
- Greedy: repetitive
- Top-k: diverse, sometimes incoherent
- Nucleus: best balance
- Beam: most fluent, but deterministic


9. Priority Queue (Heap) in Action

import heapq

# Simulate a beam: heapq is a min-heap, so the root is the WORST candidate
beam = []
heapq.heappush(beam, (-0.1, [1, 2]))  # (log_prob, seq)
heapq.heappush(beam, (-0.3, [1, 3]))
heapq.heappush(beam, (-0.2, [1, 4]))

print(heapq.heappop(beam))  # (-0.3, [1, 3]) → lowest log-prob, the one to evict
print(max(beam))            # (-0.1, [1, 2]) → best remaining candidate

10. Beam Search with Length Normalization

# Raw log-probs favor short sequences; normalize by length**alpha (alpha ≈ 0.6)
score = log_prob / (len(seq) ** 0.6)
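
To plug this into beam_search from section 4, only the selection step has to change. A minimal sketch (same candidate tuples and the alpha value above; length_normalized_select is a name introduced here for illustration):

import heapq

def length_normalized_select(all_candidates, beam_width, alpha=0.6):
    # all_candidates: list of (log_prob, seq, cache) tuples, as built inside beam_search
    return heapq.nlargest(beam_width, all_candidates,
                          key=lambda c: c[0] / (len(c[1]) ** alpha))

# Inside the beam_search loop, swap the selection line for:
# beam = length_normalized_select(all_candidates, beam_width)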

11. Summary Table

Method    Diversity   Coherence   Speed     Use Case
Greedy    Low         Medium      Fastest   Baseline
Top-k     Medium      Medium      Fast      General use
Nucleus   High        High        Fast      Best for creativity
Beam      Low         Highest     Slow      Best for accuracy

12. Practice Exercises

  1. Add length penalty to beam search
  2. Implement diverse beam search
  3. Combine top-k + nucleus (one possible sketch follows this list)
  4. Measure perplexity of outputs
  5. Visualize probability mass
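
For exercise 3, one possible sketch (not the only answer): keep the k most likely tokens first, then apply the nucleus cut inside that set, reusing the ideas from sections 5 and 6.

import torch
import torch.nn.functional as F

def top_k_top_p_sampling(logits, k=50, p=0.9, temperature=1.0):
    # logits: (1, vocab_size). Top-k first, then nucleus filtering within the survivors.
    logits = logits / temperature
    top = torch.topk(logits, k)                      # values come back sorted descending
    probs = F.softmax(top.values, dim=-1)

    cum = torch.cumsum(probs, dim=-1)
    mask = cum > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False

    probs = probs.masked_fill(mask, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)

    choice = torch.multinomial(probs, 1)             # index into the top-k set
    return torch.gather(top.indices, -1, choice)     # vocabulary id, shape (1, 1)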

13. Key Takeaways

- Beam Search = pruned breadth-first search, with a heap keeping the best k paths
- Top-k = truncate the tail of the distribution
- Nucleus = dynamic truncation by cumulative probability
- Nucleus > Top-k in practice for open-ended text
- Variants of these strategies are used in production LLMs (GPT, Claude, Gemini)

Full Copy-Paste: All Decoding Methods

import torch
import torch.nn.functional as F
import heapq

# === Top-k === (top_k and nucleus below take 1-D logits of shape (vocab_size,))
def top_k(logits, k=50, t=1.0):
    logits = logits / t
    v, i = torch.topk(logits, k)
    probs = F.softmax(v, dim=-1)
    return i[torch.multinomial(probs, 1)]  # map the sampled position back to a token id

# === Nucleus ===
def nucleus(logits, p=0.9, t=1.0):
    logits = logits / t
    probs = F.softmax(logits, dim=-1)
    s_idx = torch.argsort(probs, descending=True)
    s_probs = probs[s_idx]
    cum = torch.cumsum(s_probs, dim=-1)
    mask = cum > p
    mask[1:] = mask[:-1].clone()  # shift so the token that crosses p is kept
    mask[0] = False
    s_probs[mask] = 0
    s_probs = s_probs / s_probs.sum()
    idx = torch.multinomial(s_probs, 1)
    return s_idx[idx]

# === Beam Search ===
@torch.no_grad()
def beam(model, idx, k=5, max_len=50):
    # idx: (1, T) prompt tensor; assumes model(x, past_kv=cache) -> (logits, loss, new_cache)
    beam = [(0.0, idx[0].tolist(), None)]
    for _ in range(max_len):
        cands = []
        for lp, seq, cache in beam:
            x = torch.tensor([[seq[-1]]], dtype=torch.long) if cache is not None else torch.tensor([seq], dtype=torch.long)
            logits, _, nc = model(x, past_kv=cache)
            logp = F.log_softmax(logits[0, -1], dim=-1)
            top = logp.topk(k)
            for val, tok in zip(top.values.tolist(), top.indices.tolist()):
                cands.append((lp + val, seq + [tok], nc))
        beam = heapq.nlargest(k, cands, key=lambda c: c[0])
    return torch.tensor(max(beam, key=lambda c: c[0])[1])
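
Quick usage of the two samplers on made-up 1-D logits (beam additionally needs the cached-model interface assumed above):

import torch

torch.manual_seed(0)
logits = torch.randn(65)                     # random logits over a placeholder 65-token vocabulary
print(top_k(logits, k=10, t=0.8).item())     # a token id drawn from the 10 most likely
print(nucleus(logits, p=0.9).item())         # a token id drawn from the nucleus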

Final Words

You now control how LLMs think.
- Greedy → robot
- Beam → perfectionist
- Nucleus → creative genius


End of Module
You can now decode like modern LLMs: coherent, diverse, fast.
Next: Build a chatbot API.

Last updated: Nov 13, 2025
