Scaling Laws & Optimization
Complete Module: Big-O, Parallelism, FlashAttention, LoRA
Module Objective
Master LLM scaling: Chinchilla laws, Big-O complexity, GPU parallelism, FlashAttention, and LoRA, with the math, code, and order-of-magnitude efficiency gains behind each technique.
1. Chinchilla Scaling Laws (2022)
"Optimal training: balance model size and data"
# Optimal parameters for a given compute budget C (FLOPs), using C ≈ 6·N·D and D ≈ 20·N
def chinchilla_optimal_params(compute):
    return (compute / 120) ** 0.5                        # ≈ 70B params at Chinchilla's ~5.8e23 FLOPs

# Optimal tokens: roughly 20 tokens per parameter
def chinchilla_optimal_tokens(compute):
    return 20 * chinchilla_optimal_params(compute)       # ≈ 1.4T tokens at ~5.8e23 FLOPs
| Model | Params | Tokens | Training FLOPs (≈ 6·N·D) | Undertrained? |
|---|---|---|---|---|
| GPT-3 | 175B | 300B | ~3.1e23 | Yes |
| Chinchilla | 70B | 1.4T | ~5.8e23 | Compute-optimal |
Result: the compute-optimal 70B Chinchilla outperforms the 175B GPT-3, which saw far too few tokens for its size.
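A quick sanity check of the helpers above (a sketch; it uses the C ≈ 6·N·D and D ≈ 20·N rules of thumb rather than the paper's fitted exponents):

C = 6 * 70e9 * 1.4e12                                          # Chinchilla's budget, ≈ 5.9e23 FLOPs
print(f"optimal params: {chinchilla_optimal_params(C):.2e}")   # ≈ 7.0e10  (~70B)
print(f"optimal tokens: {chinchilla_optimal_tokens(C):.2e}")   # ≈ 1.4e12 (~1.4T)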
2. Big-O Complexity of Transformers
| Operation | Time | Memory |
|---|---|---|
| Attention | $ O(N^2 d) $ | $ O(N^2) $ |
| FFN | $ O(N d^2) $ | $ O(N d) $ |
| Total per layer | $ O(N^2 d + N d^2) $ | $ O(N^2 + N d) $ |
| L layers | $ O(L (N^2 d + N d^2)) $ | $ O(L (N^2 + N d)) $ |
Bottleneck: $ N^2 $ attention matrix
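Up to constant factors, the ratio of the two per-layer terms is N²d / Nd² = N/d, so attention dominates once the sequence length exceeds the hidden size. A quick illustration (a sketch; the values of d and N are arbitrary):

d = 4096
for N in (1_024, 4_096, 16_384, 65_536):
    attn, ffn = N**2 * d, N * d**2                     # constant factors dropped
    print(f"N={N:>6}: attention/FFN cost ratio = {attn / ffn:.2f}")
# The ratio is exactly N/d: 0.25, 1.0, 4.0, 16.0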
3. GPU Parallelism: Data, Tensor, Pipeline
Data Parallel (DP):
Full model replicated on every GPU; 8 GPUs → 8x global batch; gradients all-reduced each step
Tensor Parallel (TP):
Each weight matrix sharded across GPUs (e.g. W_q, W_k, W_v split by attention heads over 4 GPUs); partial results combined with all-reduce/all-gather
Pipeline Parallel (PP):
Layers 1–4 on GPU0, layers 5–8 on GPU1; micro-batches flow through the stages
Megatron-LM combines TP + PP (and DP) to reach ~1T parameters
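A minimal data-parallel sketch with PyTorch's DistributedDataParallel (assumptions: single node, NCCL backend, launched with torchrun; the file name train_ddp.py and the toy nn.Linear model are placeholders):

# Launch with: torchrun --nproc_per_node=8 train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
rank = int(os.environ["LOCAL_RANK"])                 # set by torchrun
torch.cuda.set_device(rank)

model = torch.nn.Linear(1024, 1024).cuda(rank)       # stand-in for the real model
model = DDP(model, device_ids=[rank])                # replicates the model, syncs gradients

x = torch.randn(32, 1024, device=f"cuda:{rank}")     # each rank sees its own slice of the batch
model(x).sum().backward()                            # gradients are all-reduced across the 8 replicas
dist.destroy_process_group()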
4. FlashAttention: O(N) Memory, 2–4x Faster
Problem: Standard Attention
attn = softmax(Q @ K.T / sqrt(d)) @ V
# → Materialize N×N matrix → O(N²) memory
FlashAttention: No materialization
# Online softmax + tiling: keep a running row max m and running sum l per query block
for i in q_blocks:
    Q_i = Q[i]
    O_i = zeros_like(Q_i); m = -inf; l = 0            # per-row statistics
    for j in kv_blocks:
        S = Q_i @ K[j].T / sqrt(d)                     # scores for one tile only
        m_new = max(m, S.max(axis=-1))                 # update running row max
        P = exp(S - m_new)                             # unnormalized probabilities
        l = l * exp(m - m_new) + P.sum(axis=-1)        # rescale old sum, add new tile
        O_i = O_i * exp(m - m_new) + P @ V[j]          # rescale old output, add new tile
        m = m_new
    O[i] = O_i / l                                     # normalize once per Q block
Memory: $ O(N) $
Speed: 2–4x faster attention; attention memory grows linearly in N instead of quadratically
from flash_attn import flash_attn_func
# q, k, v: (batch, seq_len, n_heads, head_dim), fp16/bf16 tensors on GPU
attn_output = flash_attn_func(q, k, v, causal=True)
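If the flash-attn package is unavailable, PyTorch 2.x ships a fused, memory-efficient kernel behind torch.nn.functional.scaled_dot_product_attention (not the flash-attn library, but the same idea; shapes below are illustrative):

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); fp16 on GPU lets PyTorch pick its FlashAttention backend
q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)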
5. LoRA: Train 0.1% of Parameters
"Freeze weights, train low-rank adapters"
W = W₀ + ΔW
ΔW = B A   # B: (d, r), A: (r, k), with r ≪ min(d, k)
LoRA Injection
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, linear, rank=8):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():                        # freeze the base weight
            p.requires_grad_(False)
        d_in, d_out = linear.in_features, linear.out_features
        self.A = nn.Parameter(torch.randn(rank, d_in) * 0.01)     # down-projection
        self.B = nn.Parameter(torch.zeros(d_out, rank))           # up-projection, zero-init

    def forward(self, x):
        return self.linear(x) + (x @ self.A.T) @ self.B.T         # W₀x + B(Ax)
Params: $ r(d + k) $ trainable vs $ d k $ frozen
Example: $ d = k = 4096, r = 8 $ → 65,536 trainable params ≈ 0.4% of the full weight
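A quick parameter count using the LoRALinear sketch above (d = k = 4096, r = 8):

import torch.nn as nn

base = nn.Linear(4096, 4096, bias=False)
lora = LoRALinear(base, rank=8)
trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
# → 65,536 trainable out of ~16.8M parameters, roughly 0.4%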
6. Full LoRA + FlashAttention Training
from transformers import AutoModelForCausalLM
import peft
from flash_attn import flash_attn_func

# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add LoRA adapters to the attention projections
lora_config = peft.LoraConfig(
    r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"], lora_dropout=0.1
)
model = peft.get_peft_model(model, lora_config)

# Use FlashAttention (simplified: the real GPT2Attention.forward takes extra arguments such as
# attention_mask and layer_past, and flash_attn_func needs fp16/bf16 CUDA tensors of shape
# (batch, seq, heads, head_dim))
def forward_with_flash(self, x):
    B, S, _ = x.shape
    q, k, v = self.c_attn(x).chunk(3, dim=-1)              # GPT-2 uses one fused QKV projection
    q, k, v = (t.reshape(B, S, self.num_heads, -1) for t in (q, k, v))
    out = flash_attn_func(q, k, v, causal=True)            # (B, S, heads, head_dim)
    return self.c_proj(out.reshape(B, S, -1))

# Monkey-patch the first block's attention (bind as a method; patch every layer in practice)
attn = model.transformer.h[0].attn
attn.forward = forward_with_flash.__get__(attn)
7. Compute Scaling: FLOPs
def transformer_flops(batch, seq_len, d_model, layers, vocab):
    """Rough forward-pass cost; ignores attention Q/K/V/O projections and layer norms."""
    # Embedding, counted as a matmul against the vocabulary
    flops = batch * seq_len * vocab * d_model
    # Per layer: attention score and value matmuls + 4x-wide FFN
    attn = 2 * batch * seq_len**2 * d_model
    ffn = 8 * batch * seq_len * d_model**2
    flops += layers * (attn + ffn)
    # Output projection back to the vocabulary
    flops += batch * seq_len * d_model * vocab
    return flops

print(f"GPT-3, one 2048-token forward pass: {transformer_flops(1, 2048, 12288, 96, 50257):.2e} FLOPs")
# → ≈ 2.5e14 FLOPs for a single forward pass, not the full training budget
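The print above is the cost of one forward pass. For the total training budget, the standard back-of-the-envelope is C ≈ 6·N·D (≈ 2·N·D forward plus ≈ 4·N·D backward), which recovers GPT-3's commonly quoted figure:

N_params, D_tokens = 175e9, 300e9
print(f"GPT-3 training compute ≈ {6 * N_params * D_tokens:.2e} FLOPs")   # → ≈ 3.15e23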
8. Parallelism in Code
# Tensor Parallel (simplified): column-parallel linear, output rows sharded across GPUs
import torch
import torch.nn as nn
import torch.distributed as dist

class TensorParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        # Each rank holds out_features // world_size output rows of the weight
        self.weight = nn.Parameter(torch.randn(out_features // world_size, in_features))
        self.world_size = world_size

    def forward(self, x):
        out = x @ self.weight.t()                               # local shard of the output
        # All-gather across GPUs (note: dist.all_gather does not backprop; training code needs
        # a differentiable all-gather, e.g. from torch.distributed.nn, or a custom autograd op)
        shards = [torch.empty_like(out) for _ in range(self.world_size)]
        dist.all_gather(shards, out)
        return torch.cat(shards, dim=-1)                        # reassemble the full output
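Section 3 also describes pipeline parallelism; a naive two-stage split along the layer dimension looks like this (a sketch assuming two visible GPUs; with no micro-batching it is really plain model parallelism, whereas real pipeline schedules such as GPipe or 1F1B interleave micro-batches to keep both stages busy):

import torch
import torch.nn as nn

class TwoStagePipeline(nn.Module):
    def __init__(self, d=1024, layers_per_stage=4):
        super().__init__()
        # Layers 1–4 on GPU0, layers 5–8 on GPU1, matching the split in section 3
        self.stage0 = nn.Sequential(*[nn.Linear(d, d) for _ in range(layers_per_stage)]).to("cuda:0")
        self.stage1 = nn.Sequential(*[nn.Linear(d, d) for _ in range(layers_per_stage)]).to("cuda:1")

    def forward(self, x):
        h = self.stage0(x.to("cuda:0"))
        return self.stage1(h.to("cuda:1"))        # activations hop from GPU0 to GPU1

out = TwoStagePipeline()(torch.randn(8, 1024))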
9. Optimization: AdamW + Gradient Checkpointing
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
# Gradient checkpointing: recompute activations during backward to trade compute for memory
model.gradient_checkpointing_enable()   # built into Hugging Face PreTrainedModel
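For a plain nn.Sequential stack, the lower-level API is torch.utils.checkpoint.checkpoint_sequential, which keeps activations only at segment boundaries and recomputes the rest during backward (a sketch with a toy 8-layer stack):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
x = torch.randn(4, 1024, requires_grad=True)
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)   # 4 segments of 2 layers each
out.sum().backward()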
10. Summary Table
| Technique | Speed | Memory | Params Trained |
|---|---|---|---|
| FlashAttention | 2–4x faster attention | attention memory O(N) instead of O(N²) | 100% |
| LoRA | ~1x | optimizer/gradient state cut by ~99% | ~0.1% |
| Tensor Parallel | scales with GPU count (minus communication) | weights sharded across GPUs | 100% |
| Gradient Checkpointing | ~0.7x | activation memory cut by ~60–70% | 100% |
11. Practice Exercises
- Train LoRA on TinyShakespeare
- Benchmark FlashAttention vs standard
- Plot Chinchilla curve
- Implement pipeline parallelism
- Combine LoRA + FlashAttention
12. Key Takeaways
| ✓ | Insight |
|---|---|
| ✓ | Chinchilla: compute-optimal 70B beats undertrained 175B |
| ✓ | Attention = O(N²) |
| ✓ | FlashAttention = O(N) memory |
| ✓ | LoRA = ~0.1% trainable params |
| ✓ | Scale efficiently |
Full Copy-Paste: LoRA + FlashAttention
!pip install flash-attn peft transformers

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from flash_attn import flash_attn_func

# Load model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# LoRA on the attention projections
config = LoraConfig(r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"])
model = get_peft_model(model, config)

# Replace attention (simplified: flash_attn_func expects fp16/bf16 CUDA tensors of shape
# (batch, seq, heads, head_dim), and the real GPT2Attention.forward takes extra arguments)
def flash_forward(self, x):
    B, S, _ = x.shape
    q, k, v = self.c_attn(x).chunk(3, dim=-1)
    q, k, v = (t.reshape(B, S, self.num_heads, -1) for t in (q, k, v))
    out = flash_attn_func(q, k, v, causal=True)
    return self.c_proj(out.reshape(B, S, -1))

# Patch the first layer (patch every layer the same way in real code)
attn = model.transformer.h[0].attn
attn.forward = flash_forward.__get__(attn)

print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
Final Words
You can now train LLMs the way DeepMind, Meta, and Google do:
- Chinchilla-optimal
- FlashAttention-fast
- LoRA-efficient
- Scalable toward 1T parameters
End of Module
You scale like the pros: efficient, fast, optimal.
Next: Build a 7B model.