Encoder-Decoder Transformers
Complete Module: Cross-Attention, Seq2Seq, Machine Translation, Mini-T5 (64-dim)
Module Objective
Master the Encoder-Decoder Transformer — cross-attention, seq2seq, machine translation, with full PyTorch implementation of a Mini-T5 (64-dim).
1. Encoder-Decoder vs Decoder-Only
| Architecture | Use Case | Attention Types |
|---|---|---|
| Decoder-Only | Text generation | Self-attention (causal) |
| Encoder-Decoder | Translation, Summarization | Self + Cross |
2. Encoder-Decoder Architecture
Source (src) → [Encoder] → Memory (contextual states, supplies K and V)
                                        ↓
Target (tgt) → [Decoder] → Cross-Attention (Q from tgt, K and V from memory) → Logits
- Encoder: Bidirectional self-attention
- Decoder: Causal self-attention + cross-attention
3. Three Types of Attention
# Conceptually, with `mha = MultiHeadAttention(n_embd, n_head)` instances (implemented at the end of this module):
# 1. Encoder self-attention (bidirectional): Q = K = V = encoder states
attn_enc, _ = mha(enc_x, enc_x, enc_x)
# 2. Decoder self-attention (causal): Q = K = V = decoder states, future positions masked
attn_dec_self, _ = mha(dec_x, dec_x, dec_x, mask=causal_mask)
# 3. Decoder cross-attention: Q from the decoder, K and V from the encoder output
attn_cross, _ = mha(dec_x, enc_x, enc_x)
4. Full Encoder-Decoder Block
class EncoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd)
    def forward(self, x, src_mask=None):
        # Bidirectional self-attention (pre-LN), with an optional padding mask
        h = self.ln1(x)
        x = x + self.attn(h, h, h, mask=src_mask)[0]
        x = x + self.ff(self.ln2(x))
        return x
class DecoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.self_attn = CausalMultiHeadAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.cross_attn = MultiHeadAttention(n_embd, n_head)
        self.ln3 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd)
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-attention (causal)
        x = x + self.self_attn(self.ln1(x), mask=tgt_mask)[0]
        # Cross-attention
        x = x + self.cross_attn(self.ln2(x), enc_output, enc_output, mask=src_mask)[0]
        # FFN
        x = x + self.ff(self.ln3(x))
        return x
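A quick shape check may help here. The sketch below uses made-up batch and sequence sizes and assumes the MultiHeadAttention, CausalMultiHeadAttention, and FeedForward classes from the copy-paste section at the end of this module; the point is that both blocks keep the embedding dimension while their sequence lengths differ.
import torch
n_embd, n_head = 64, 4
enc_block = EncoderBlock(n_embd, n_head)
dec_block = DecoderBlock(n_embd, n_head)
src_states = torch.randn(2, 7, n_embd)    # (B, S_src, n_embd)
tgt_states = torch.randn(2, 5, n_embd)    # (B, S_tgt, n_embd)
memory = enc_block(src_states)            # encoder output, reused by every decoder layer
out = dec_block(tgt_states, memory)       # cross-attention mixes 5 target queries with 7 source keys
print(memory.shape, out.shape)            # torch.Size([2, 7, 64]) torch.Size([2, 5, 64])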
5. Mini-T5 (64-dim) — Full Implementation
class MiniT5(nn.Module):
    def __init__(self, src_vocab=1000, tgt_vocab=1000, n_embd=64, n_head=4, n_layer=3, max_len=128):
        super().__init__()
        self.n_embd = n_embd
        self.src_emb = nn.Embedding(src_vocab, n_embd)
        self.tgt_emb = nn.Embedding(tgt_vocab, n_embd)
        self.pos_emb = nn.Embedding(max_len, n_embd)
        self.encoder = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.decoder = nn.ModuleList([DecoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, tgt_vocab, bias=False)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)
    def encode(self, src, src_mask=None):
        x = self.src_emb(src) + self.pos_emb(torch.arange(src.size(1), device=src.device))
        for block in self.encoder:
            x = block(x, src_mask)
        return x  # Memory for cross-attention
    def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
        x = self.tgt_emb(tgt) + self.pos_emb(torch.arange(tgt.size(1), device=tgt.device))
        for block in self.decoder:
            x = block(x, memory, src_mask, tgt_mask)
        x = self.ln_f(x)
        return self.lm_head(x)
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        memory = self.encode(src, src_mask)
        logits = self.decode(tgt, memory, src_mask, tgt_mask)
        return logits
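As a quick smoke test (random ids and default vocabulary sizes, purely illustrative), the model maps a source and target batch to logits over the target vocabulary:
model = MiniT5(src_vocab=1000, tgt_vocab=1000, n_embd=64, n_head=4, n_layer=3)
src = torch.randint(3, 1000, (2, 7))      # (B, S_src)
tgt = torch.randint(3, 1000, (2, 5))      # (B, S_tgt)
logits = model(src, tgt)                  # (B, S_tgt, tgt_vocab)
print(logits.shape, sum(p.numel() for p in model.parameters()), "parameters")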
6. Masks
def create_padding_mask(seq, pad_idx=0):
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,S), True = real token
def create_causal_mask(seq_len):
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0  # (S,S), True = may attend
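A quick look at what these helpers return, using a made-up padded sentence (run in the same session as the code above):
seq = torch.tensor([[5, 6, 7, 0]])        # one source sentence padded with <pad> = 0
pad_mask = create_padding_mask(seq)       # (1, 1, 1, 4); False at the padded position
causal = create_causal_mask(3)            # (3, 3); True on and below the diagonal
print(pad_mask.shape, pad_mask[0, 0, 0].tolist())   # torch.Size([1, 1, 1, 4]) [True, True, True, False]
print(causal.tolist())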
7. Training: Teacher Forcing
# src:  "<s> hello </s>"   → e.g. [1, 100, 2]   (source token ids)
# tgt:  "<s> bonjour </s>" → e.g. [1, 400, 2]   (target token ids)
# Shift the target: tgt_input = tgt[:, :-1] goes into the decoder, tgt_label = tgt[:, 1:] is what it must predict
logits = model(src, tgt_input, src_mask, tgt_mask)
loss = F.cross_entropy(logits.view(-1, tgt_vocab_size), tgt_label.view(-1))  # tgt_vocab_size = target vocabulary size
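The shift in isolation, as a tiny runnable check with toy ids (assuming torch is imported):
tgt = torch.tensor([[1, 400, 2]])                  # <s> bonjour </s>
tgt_input, tgt_label = tgt[:, :-1], tgt[:, 1:]     # decoder input vs. training labels
print(tgt_input.tolist(), tgt_label.tolist())      # [[1, 400]] [[400, 2]]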
8. Inference: Autoregressive Decoding
# Greedy decoding. Define this as a method of MiniT5, or attach it as shown below.
@torch.no_grad()
def generate(self, src, max_len=50, sos_idx=1, eos_idx=2):
    memory = self.encode(src)                                   # run the encoder once
    tgt = torch.tensor([[sos_idx]], device=src.device)          # start from <s>
    for _ in range(max_len):
        logits = self.decode(tgt, memory)
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)  # most likely next token
        tgt = torch.cat([tgt, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    return tgt
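The later sections call model.generate(src), so attach the function to the class (one simple way of wiring it up; defining it inside MiniT5 works just as well):
MiniT5.generate = generate   # greedy decoding is now available on every MiniT5 instance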
9. Mini Translation Dataset (English → French)
pairs = [
("hello", "bonjour"),
("thank you", "merci"),
("good morning", "bonjour"),
("how are you", "comment allez-vous"),
("i love you", "je t'aime"),
]
# Build vocab
src_vocab = {'<pad>':0, '<s>':1, '</s>':2}
tgt_vocab = {'<pad>':0, '<s>':1, '</s>':2}
for en, fr in pairs:
    for w in en.split(): src_vocab.setdefault(w, len(src_vocab))
    for w in fr.split(): tgt_vocab.setdefault(w, len(tgt_vocab))
# Encode
def encode_pair(en, fr):
    src = [1] + [src_vocab[w] for w in en.split()] + [2]
    tgt = [1] + [tgt_vocab[w] for w in fr.split()] + [2]
    return torch.tensor(src), torch.tensor(tgt)
dataset = [encode_pair(en, fr) for en, fr in pairs]
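What one encoded pair looks like with the insertion order above (ids beyond 2 are assigned as words are first seen):
src_ids, tgt_ids = encode_pair("thank you", "merci")
print(src_ids.tolist(), tgt_ids.tolist())   # [1, 4, 5, 2] [1, 4, 2]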
10. Train Mini-T5
model = MiniT5(len(src_vocab), len(tgt_vocab), n_embd=64, n_head=4, n_layer=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(500):
    total_loss = 0
    for src, tgt in dataset:
        src = src.unsqueeze(0)                       # (1, S_src)
        tgt_input = tgt[:-1].unsqueeze(0)            # (1, S_tgt - 1): drop </s>
        tgt_label = tgt[1:].unsqueeze(0)             # (1, S_tgt - 1): drop <s>
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = create_causal_mask(tgt_input.size(1))
        logits = model(src, tgt_input, src_mask, tgt_mask)
        loss = F.cross_entropy(logits.view(-1, len(tgt_vocab)), tgt_label.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")
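The loop above trains one sentence pair at a time. As an optional variation (a sketch, assuming the pad id stays 0), the pairs can be padded into a single batch and masked with the helpers from Section 6:
from torch.nn.utils.rnn import pad_sequence
src_batch = pad_sequence([s for s, _ in dataset], batch_first=True, padding_value=0)   # (B, S_max)
tgt_batch = pad_sequence([t for _, t in dataset], batch_first=True, padding_value=0)   # (B, T_max)
src_mask = create_padding_mask(src_batch)                 # (B, 1, 1, S_max)
tgt_mask = create_causal_mask(tgt_batch.size(1) - 1)      # causal mask for the shifted target input
logits = model(src_batch, tgt_batch[:, :-1], src_mask, tgt_mask)
loss = F.cross_entropy(logits.reshape(-1, len(tgt_vocab)),
                       tgt_batch[:, 1:].reshape(-1),
                       ignore_index=0)                    # ignore <pad> positions in the labels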
11. Generate Translation
src = torch.tensor([[1, src_vocab["hello"], 2]]) # <s> hello </s>
output = model.generate(src)
tgt_vocab_inv = {i: w for w, i in tgt_vocab.items()}                   # id → word
translated = [tgt_vocab_inv[i] for i in output[0].tolist() if i > 2]   # drop <pad>/<s>/</s>
print("Input:", "hello")
print("Output:", " ".join(translated))
# → "bonjour"
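To spot-check memorization across the whole toy dataset, a small optional loop reusing generate and the vocabularies above:
for en, fr in pairs:
    ids = [1] + [src_vocab[w] for w in en.split()] + [2]          # <s> ... </s>
    out = model.generate(torch.tensor([ids]))
    words = [tgt_vocab_inv[i] for i in out[0].tolist() if i > 2]  # drop special tokens
    print(f"{en!r} -> {' '.join(words)}   (reference: {fr!r})")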
12. Visualization: Cross-Attention Heatmap
import matplotlib.pyplot as plt
import seaborn as sns
# Hook the first decoder layer's cross-attention to capture its weights
attn_weights = []
def hook(module, input, output):
    attn_weights.append(output[1].detach())   # output = (projected values, attention weights)
handle = model.decoder[0].cross_attn.register_forward_hook(hook)
src = torch.tensor([[1, src_vocab["thank"], src_vocab["you"], 2]])   # <s> thank you </s>
model.generate(src)   # generate starts from <s> itself; the hook fires once per decoding step
handle.remove()
# Plot the first decoding step (query = <s>) against the four source tokens
sns.heatmap(attn_weights[0][0, 0].cpu(), annot=True, cmap="Blues",
            xticklabels=["<s>", "thank", "you", "</s>"],
            yticklabels=["<s>"])
plt.title("Cross-Attention: Decoder <s> → Encoder")
plt.show()
13. Summary Table
| Component | Encoder | Decoder |
|---|---|---|
| Self-Attention | Bidirectional | Causal |
| Cross-Attention | No | Yes (Q=dec, K=V=enc) |
| Mask | Padding | Padding + Causal |
| Output | Memory (contextual source states) | Target-token logits |
14. Practice Exercises
- Add beam search (a minimal sketch follows this list)
- Train on reverse (French → English)
- Visualize all attention heads
- Add shared embeddings
- Implement T5-style text-to-text
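For the first exercise, here is one possible starting point: a minimal beam-search sketch (an illustration, not part of the module's reference code) that keeps the beam_size highest-scoring partial translations at each step. A common refinement is to normalize scores by length before picking the final beam.
@torch.no_grad()
def beam_search(model, src, beam_size=3, max_len=20, sos_idx=1, eos_idx=2):
    memory = model.encode(src)
    beams = [(torch.tensor([[sos_idx]], device=src.device), 0.0)]   # (tokens, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[0, -1].item() == eos_idx:          # finished beam: carry it forward unchanged
                candidates.append((tokens, score))
                continue
            logits = model.decode(tokens, memory)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
            top_lp, top_ids = log_probs.topk(beam_size, dim=-1)
            for lp, idx in zip(top_lp[0], top_ids[0]):
                candidates.append((torch.cat([tokens, idx.view(1, 1)], dim=1), score + lp.item()))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(t[0, -1].item() == eos_idx for t, _ in beams):
            break
    return beams[0][0]   # best-scoring sequence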
15. Key Takeaways
- ✓ The encoder is a bidirectional context encoder
- ✓ The decoder is autoregressive and adds cross-attention
- ✓ Cross-attention conditions generation on the source sequence
- ✓ A Mini-T5 works even at 64 dimensions
- ✓ The same design powers T5, BART, and machine translation systems
Full Copy-Paste Mini-T5
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_k = n_embd // n_head
        self.Wq = nn.Linear(n_embd, n_embd)
        self.Wk = nn.Linear(n_embd, n_embd)
        self.Wv = nn.Linear(n_embd, n_embd)
        self.Wo = nn.Linear(n_embd, n_embd)
    def forward(self, Q, K, V, mask=None):
        B, T_q, C = Q.shape
        T_k = K.size(1)  # in cross-attention the key/value length differs from the query length
        q = self.Wq(Q).view(B, T_q, self.n_head, self.d_k).transpose(1, 2)
        k = self.Wk(K).view(B, T_k, self.n_head, self.d_k).transpose(1, 2)
        v = self.Wv(V).view(B, T_k, self.n_head, self.d_k).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)   # (B, n_head, T_q, T_k)
        if mask is not None:
            att = att.masked_fill(~mask, -1e9)                # mask: True = attend, False = block
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T_q, C)
        return self.Wo(y), att                                # weights returned for visualization
class CausalMultiHeadAttention(MultiHeadAttention):
    def forward(self, x, mask=None):
        B, T, C = x.shape
        causal_mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
        if mask is not None:
            mask = mask & causal_mask
        else:
            mask = causal_mask
        return super().forward(x, x, x, mask)
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd * 4), nn.GELU(), nn.Linear(n_embd * 4, n_embd))
    def forward(self, x):
        return self.net(x)
# EncoderBlock, DecoderBlock, MiniT5 as above
Final Words
You just built a miniature T5-style encoder-decoder from scratch.
Encoder-Decoder = the original Transformer — still powers translation, summarization, and more.
End of Module
You now control both halves of the Transformer family.
Next: Pretrain a 125M model.