Encoder-Decoder Transformers

Complete Module: Cross-Attention, Seq2Seq, Machine Translation, Mini-T5 (64-dim)


Module Objective

Master the Encoder-Decoder Transformer: cross-attention, seq2seq, and machine translation, with a full PyTorch implementation of a Mini-T5 (64-dim).


1. Encoder-Decoder vs Decoder-Only

Architecture    | Use Case                   | Attention Types
Decoder-Only    | Text generation            | Self-attention (causal)
Encoder-Decoder | Translation, Summarization | Self + Cross

2. Encoder-Decoder Architecture

Input (src) → [Encoder] → Memory (K, V)
                         ↘
Output (tgt) → [Decoder] → Cross-Attention(Q from tgt, K,V from src)
  • Encoder: Bidirectional self-attention
  • Decoder: Causal self-attention + cross-attention
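
To make the flow concrete, here is a shape walk-through (a sketch; B, S, T and the target vocab size V_tgt are illustrative, and n_embd=64 matches the Mini-T5 built below):

# src:    (B, S)        source token IDs
# memory: (B, S, 64)    encoder output, reused as K/V by every decoder layer
# tgt:    (B, T)        target token IDs (shifted right during training)
# logits: (B, T, V_tgt) one distribution over the target vocabulary per position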

3. Three Types of Attention

# A single module is shown for brevity; the real blocks below use separate weights per role.
mha = MultiHeadAttention(n_embd, n_head)

# 1. Encoder: self-attention (bidirectional), Q = K = V = encoder states
attn_enc, _ = mha(enc_x, enc_x, enc_x)

# 2. Decoder: self-attention (causal), each position only attends to the past
attn_dec_self, _ = mha(dec_x, dec_x, dec_x, mask=causal_mask)

# 3. Decoder: cross-attention
attn_cross, _ = mha(dec_x, enc_x, enc_x)  # Q=dec, K=V=enc

4. Full Encoder-Decoder Block

class EncoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd)
    def forward(self, x, src_mask=None):
        # Bidirectional self-attention: Q = K = V = x
        h = self.ln1(x)
        x = x + self.attn(h, h, h, mask=src_mask)[0]
        x = x + self.ff(self.ln2(x))
        return x

class DecoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.self_attn = CausalMultiHeadAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.cross_attn = MultiHeadAttention(n_embd, n_head)
        self.ln3 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd)
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-attention (causal)
        x = x + self.self_attn(self.ln1(x), mask=tgt_mask)[0]
        # Cross-attention
        x = x + self.cross_attn(self.ln2(x), enc_output, enc_output, mask=src_mask)[0]
        # FFN
        x = x + self.ff(self.ln3(x))
        return x

5. Mini-T5 (64-dim) — Full Implementation

class MiniT5(nn.Module):
    def __init__(self, src_vocab=1000, tgt_vocab=1000, n_embd=64, n_head=4, n_layer=3, max_len=128):
        super().__init__()
        self.n_embd = n_embd
        self.src_emb = nn.Embedding(src_vocab, n_embd)
        self.tgt_emb = nn.Embedding(tgt_vocab, n_embd)
        self.pos_emb = nn.Embedding(max_len, n_embd)

        self.encoder = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.decoder = nn.ModuleList([DecoderBlock(n_embd, n_head) for _ in range(n_layer)])

        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, tgt_vocab, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def encode(self, src, src_mask=None):
        x = self.src_emb(src) + self.pos_emb(torch.arange(src.size(1), device=src.device))
        for block in self.encoder:
            x = block(x, src_mask)
        return x  # Memory for cross-attention

    def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
        x = self.tgt_emb(tgt) + self.pos_emb(torch.arange(tgt.size(1), device=tgt.device))
        for block in self.decoder:
            x = block(x, memory, src_mask, tgt_mask)
        x = self.ln_f(x)
        return self.lm_head(x)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        memory = self.encode(src, src_mask)
        logits = self.decode(tgt, memory, src_mask, tgt_mask)
        return logits
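
A quick sanity check of the model size (a sketch; the exact count depends on the vocab sizes you pass in):

model = MiniT5(src_vocab=1000, tgt_vocab=1000, n_embd=64, n_head=4, n_layer=3)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params/1e6:.2f}M parameters")  # roughly half a million at these settings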

6. Masks

def create_padding_mask(seq, pad_idx=0):
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,S)

def create_causal_mask(seq_len):
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0
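
A small demo of the two mask shapes and how they broadcast against the (B, n_head, T_q, T_k) attention scores (a sketch; token IDs are illustrative, <pad> = 0):

seq = torch.tensor([[5, 6, 7, 0]])     # one sequence, last position is padding
pad_mask = create_padding_mask(seq)    # (1, 1, 1, 4); broadcasts over heads and query positions
causal = create_causal_mask(3)         # (3, 3) lower-triangular bool matrix
print(pad_mask.shape, causal)
# Decoder self-attention uses the causal mask (plus target padding, if any);
# cross-attention uses only the source padding mask.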

7. Training: Teacher Forcing

# src: "<s> hello </s>"   → e.g. [1, 100, 2]   (token IDs are illustrative)
# tgt: "<s> bonjour </s>" → e.g. [1, 400, 2]
# Teacher forcing: feed the target shifted right, predict the next token
# tgt_input = tgt[:, :-1], tgt_label = tgt[:, 1:]

logits = model(src, tgt_input, src_mask, tgt_mask)
loss = F.cross_entropy(logits.view(-1, tgt_vocab), tgt_label.view(-1))
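
If the target batch contains <pad> tokens, the loss should skip them. A minimal variant of the line above (an assumption on top of the code shown; uses <pad> = 0 as in the toy vocab of section 9):

loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    tgt_label.view(-1),
    ignore_index=0,  # assumes <pad> has id 0
)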

8. Inference: Autoregressive Decoding

# Add this as a method of MiniT5
@torch.no_grad()
def generate(self, src, max_len=50, sos_idx=1, eos_idx=2):
    memory = self.encode(src)
    tgt = torch.tensor([[sos_idx]], device=src.device)

    for _ in range(max_len):
        logits = self.decode(tgt, memory)
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    return tgt
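
The loop above decodes greedily (argmax). A sketch of swapping in temperature + top-k sampling for more varied outputs; this replaces the argmax line, and the temperature 0.8 and k=10 are illustrative choices, not values from this module:

probs = F.softmax(logits[:, -1, :] / 0.8, dim=-1)   # temperature-scaled distribution
topk_probs, topk_idx = probs.topk(10, dim=-1)       # keep the 10 most likely tokens
next_token = topk_idx.gather(-1, torch.multinomial(topk_probs, 1))  # sample among them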

9. Mini Translation Dataset (English → French)

pairs = [
    ("hello", "bonjour"),
    ("thank you", "merci"),
    ("good morning", "bonjour"),
    ("how are you", "comment allez-vous"),
    ("i love you", "je t'aime"),
]

# Build vocab
src_vocab = {'<pad>':0, '<s>':1, '</s>':2}
tgt_vocab = {'<pad>':0, '<s>':1, '</s>':2}

for en, fr in pairs:
    for w in en.split(): src_vocab.setdefault(w, len(src_vocab))
    for w in fr.split(): tgt_vocab.setdefault(w, len(tgt_vocab))

# Encode
def encode_pair(en, fr):
    src = [1] + [src_vocab[w] for w in en.split()] + [2]
    tgt = [1] + [tgt_vocab[w] for w in fr.split()] + [2]
    return torch.tensor(src), torch.tensor(tgt)

dataset = [encode_pair(en, fr) for en, fr in pairs]
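
The training loop below feeds one pair at a time. To batch variable-length pairs instead, here is a minimal collate sketch using pad_sequence (assumes <pad> = 0; not used by the loop below):

from torch.nn.utils.rnn import pad_sequence

def collate(batch):
    srcs, tgts = zip(*batch)
    src = pad_sequence(srcs, batch_first=True, padding_value=0)  # (B, S_max)
    tgt = pad_sequence(tgts, batch_first=True, padding_value=0)  # (B, T_max)
    return src, tgt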

10. Train Mini-T5

model = MiniT5(len(src_vocab), len(tgt_vocab), n_embd=64, n_head=4, n_layer=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(500):
    total_loss = 0
    for src, tgt in dataset:
        src = src.unsqueeze(0)
        tgt_input = tgt[:-1].unsqueeze(0)
        tgt_label = tgt[1:].unsqueeze(0)

        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = create_causal_mask(tgt_input.size(1))

        logits = model(src, tgt_input, src_mask, tgt_mask)
        loss = F.cross_entropy(logits.view(-1, len(tgt_vocab)), tgt_label.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")

11. Generate Translation

src = torch.tensor([[1, src_vocab["hello"], 2]])  # <s> hello </s>
output = model.generate(src)
translated = [tgt_vocab_inv[i] for i in output[0].tolist() if i > 2]
print("Input:", "hello")
print("Output:", " ".join(translated))
# → "bonjour"

12. Visualization: Cross-Attention Heatmap

import matplotlib.pyplot as plt
import seaborn as sns

# Hook cross-attention
attn_weights = []
def hook(module, input, output):
    attn_weights.append(output[1].detach())

handle = model.decoder[0].cross_attn.register_forward_hook(hook)

src = torch.tensor([[1, src_vocab["thank"], src_vocab["you"], 2]])
tgt = torch.tensor([[1]])  # <s>
model.generate(src)

handle.remove()

# Plot
sns.heatmap(attn_weights[0][0, 0].cpu(), annot=True, cmap="Blues",
            xticklabels=["<s>", "thank", "you", "</s>"],
            yticklabels=["<s>"])
plt.title("Cross-Attention: Decoder <s> → Encoder")
plt.show()

13. Summary Table

Component       | Encoder       | Decoder
Self-Attention  | Bidirectional | Causal
Cross-Attention | No            | Yes (Q=dec, K=V=enc)
Mask            | Padding       | Padding + Causal
Output          | Memory        | Translation

14. Practice Exercises

  1. Add beam search
  2. Train on reverse (French → English)
  3. Visualize all attention heads
  4. Add shared embeddings
  5. Implement T5-style text-to-text

15. Key Takeaways

  • Encoder = context encoder
  • Decoder = autoregressive + cross-attention
  • Cross-attention = conditioned generation
  • Mini-T5 works with 64-dim!
  • Used in T5, BART, and machine translation systems

Full Copy-Paste Mini-T5

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_k = n_embd // n_head
        self.Wq = nn.Linear(n_embd, n_embd)
        self.Wk = nn.Linear(n_embd, n_embd)
        self.Wv = nn.Linear(n_embd, n_embd)
        self.Wo = nn.Linear(n_embd, n_embd)
    def forward(self, Q, K, V, mask=None):
        B, T_q, C = Q.shape
        T_k = K.size(1)  # K/V may come from the encoder and have a different length than Q
        q = self.Wq(Q).view(B, T_q, self.n_head, self.d_k).transpose(1,2)
        k = self.Wk(K).view(B, T_k, self.n_head, self.d_k).transpose(1,2)
        v = self.Wv(V).view(B, T_k, self.n_head, self.d_k).transpose(1,2)
        att = (q @ k.transpose(-2,-1)) / (self.d_k**0.5)
        if mask is not None: att = att.masked_fill(~mask, -1e9)
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1,2).contiguous().view(B, T_q, C)
        return self.Wo(y), att

class CausalMultiHeadAttention(MultiHeadAttention):
    def forward(self, x, mask=None):
        B, T, C = x.shape
        causal_mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
        if mask is not None: mask = mask & causal_mask
        else: mask = causal_mask
        return super().forward(x, x, x, mask)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd*4), nn.GELU(), nn.Linear(n_embd*4, n_embd))
    def forward(self, x):
        return self.net(x)

# EncoderBlock, DecoderBlock, MiniT5 as above
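
A minimal smoke test for the assembled model (a sketch; paste the EncoderBlock, DecoderBlock, and MiniT5 definitions from sections 4 and 5 above first):

model = MiniT5(src_vocab=100, tgt_vocab=100)
src = torch.randint(3, 100, (2, 7))   # (B=2, S=7) random source tokens
tgt = torch.randint(3, 100, (2, 5))   # (B=2, T=5) random target tokens
logits = model(src, tgt)
print(logits.shape)                   # expected: torch.Size([2, 5, 100])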

Final Words

You just built a T5-style encoder-decoder from scratch.
Encoder-decoder is the original Transformer layout, and it still powers translation, summarization, and more.


End of Module
You now control both halves of the Transformer family.
Next: Pretrain a 125M model.

Last updated: Nov 13, 2025
