Scaling Laws & Optimization

Complete Module: Big-O, Parallelism, FlashAttention, LoRA



Module Objective

Master LLM scaling: Chinchilla laws, Big-O complexity, GPU parallelism, FlashAttention, and LoRA, with math, code, and order-of-magnitude efficiency gains.


1. Chinchilla Scaling Laws (2022)

"Optimal training: balance model size and data"

# Chinchilla rule of thumb: ~20 training tokens per parameter,
# with training compute C ≈ 6 * N_params * D_tokens
def chinchilla_optimal_params(compute):
    return (compute / 120) ** 0.5       # ≈ 70B params for C ≈ 5.9e23 FLOPs

def chinchilla_optimal_tokens(compute):
    return 20 * (compute / 120) ** 0.5  # ≈ 1.4T tokens for the same budget
Model        Params   Tokens   Training compute   Compute-optimal?
GPT-3        175B     300B     ~3.1e23 FLOPs      No (undertrained on data)
Chinchilla   70B      1.4T     ~5.8e23 FLOPs      Yes

Result: at a matched compute budget, a smaller model trained on more tokens wins: Chinchilla (70B) outperformed both Gopher (280B) and GPT-3 (175B).
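
As a quick sanity check on the helpers above (the constants encode the ~20 tokens-per-parameter rule of thumb, not the paper's full fitted loss surface), here is how a GPT-3-sized budget would be spent compute-optimally:

budget = 3.1e23  # ≈ GPT-3's training compute in FLOPs (6 * 175e9 params * 300e9 tokens)
n_opt = chinchilla_optimal_params(budget)
d_opt = chinchilla_optimal_tokens(budget)
print(f"Compute-optimal: {n_opt / 1e9:.0f}B params on {d_opt / 1e12:.2f}T tokens")
# → roughly 51B params on ~1.0T tokens: far smaller than 175B, trained on far more data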


2. Big-O Complexity of Transformers

Operation        Time                     Memory
Attention        $ O(N^2 d) $             $ O(N^2) $
FFN              $ O(N d^2) $             $ O(N d) $
Total per layer  $ O(N^2 d + N d^2) $     $ O(N^2 + N d) $
L layers         $ O(L(N^2 d + N d^2)) $  $ O(L(N^2 + N d)) $

Bottleneck: $ N^2 $ attention matrix
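
To make the $ N^2 $ term concrete, here is a small back-of-the-envelope script (head count and fp16 activations are assumed values for illustration):

def attention_score_bytes(n_tokens, n_heads=32, bytes_per_elem=2):
    # One N x N score matrix per head, stored in fp16
    return n_heads * n_tokens**2 * bytes_per_elem

for n in (1_024, 8_192, 32_768):
    print(f"N = {n:>6}: score matrices ≈ {attention_score_bytes(n) / 2**30:7.1f} GiB per layer")
# Quadrupling N multiplies the score-matrix memory by 16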


3. GPU Parallelism: Data, Tensor, Pipeline

Data Parallel (DP): 
  8 GPUs → 8x batch → same model

Tensor Parallel (TP): 
  Each weight matrix split across GPUs → e.g., attention heads (columns of W_q, W_k, W_v) sharded over 4 GPUs

Pipeline Parallel (PP): 
  Layers 1–4 on GPU0, 5–8 on GPU1

Megatron-LM: TP + PP → 1T params
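
Of the three, data parallelism is the simplest to try. A minimal sketch with PyTorch DDP, assuming a torchrun launch; the model and batch here are placeholders:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")      # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(1024, 1024).cuda()   # placeholder model
    model = DDP(model, device_ids=[local_rank])  # gradients are all-reduced across GPUs

    x = torch.randn(32, 1024, device="cuda")     # placeholder batch (different per rank)
    model(x).pow(2).mean().backward()            # DDP syncs gradients during backward
    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=8 ddp_demo.py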


4. FlashAttention: O(N) Memory, 2–4x Faster

Problem: Standard Attention

attn = softmax(Q @ K.T / sqrt(d)) @ V
# → Materialize N×N matrix → O(N²) memory

FlashAttention: No materialization

# Online softmax + tiling (simplified pseudocode)
for i in q_blocks:
    O_i, m_i, l_i = 0, -inf, 0               # running output, max, normalizer
    for j in kv_blocks:
        S = Q[i] @ K[j].T / sqrt(d)          # small tile, fits in on-chip SRAM
        m_new = max(m_i, S.max())
        P = exp(S - m_new)
        O_i = O_i * exp(m_i - m_new) + P @ V[j]
        l_i = l_i * exp(m_i - m_new) + P.sum()
        m_i = m_new
    O[i] = O_i / l_i                         # exact softmax, no N x N matrix ever stored

Memory: $ O(N) $ extra — the $ N \times N $ score matrix is never written to GPU HBM
Speed: 2–4x faster attention, mostly from reduced memory traffic

from flash_attn import flash_attn_func

# q, k, v: (batch, seq, heads, head_dim) tensors in fp16/bf16 on GPU
attn_output = flash_attn_func(q, k, v, causal=True)
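
If installing flash-attn is impractical, PyTorch's built-in scaled_dot_product_attention dispatches to a fused FlashAttention-style kernel on supported GPUs. A minimal sketch with assumed shapes (batch, heads, seq, head_dim):

import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Fused attention: the full N x N score matrix is not materialized on supported backends
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 16, 4096, 64])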

5. LoRA: Train 0.1% of Parameters

"Freeze weights, train low-rank adapters"

W = W₀ + ΔW
ΔW = B A    # B: (d, r), A: (r, k) → r << d

LoRA Injection

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, linear, rank=8, alpha=16):
        super().__init__()
        self.linear = linear                                   # frozen base layer W0
        self.linear.weight.requires_grad_(False)
        d_in, d_out = linear.in_features, linear.out_features
        self.A = nn.Parameter(torch.randn(rank, d_in) * 0.01)  # A: (r, k)
        self.B = nn.Parameter(torch.zeros(d_out, rank))        # B: (d, r), zero init so ΔW = 0 at start
        self.scale = alpha / rank

    def forward(self, x):
        # W0 x + (alpha / r) * B A x
        return self.linear(x) + self.scale * (x @ self.A.T @ self.B.T)

Params: $ r(d + k) $ trainable vs $ d k $ in the frozen weight
Example: $ d = k = 4096, r = 8 $ → 65,536 vs 16.8M ≈ 0.4% of the weight
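
A quick usage check of the sketch above; the percentages follow from the shapes alone:

import torch.nn as nn

base = nn.Linear(4096, 4096, bias=False)
lora = LoRALinear(base, rank=8)

trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
# → 65,536 trainable out of ~16.8M (≈ 0.4%)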


6. Full LoRA + FlashAttention Training

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from flash_attn import flash_attn_func

# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add LoRA adapters on the attention projections
lora_config = LoraConfig(
    r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"], lora_dropout=0.1
)
model = get_peft_model(model, lora_config)

# Route attention through FlashAttention (simplified: GPT-2 fuses Q, K, V in c_attn,
# and the real GPT2Attention.forward takes extra arguments that this demo ignores)
def forward_with_flash(self, hidden_states, **kwargs):
    B, S, _ = hidden_states.shape
    q, k, v = self.c_attn(hidden_states).chunk(3, dim=-1)
    q, k, v = (t.view(B, S, self.num_heads, self.head_dim) for t in (q, k, v))
    out = flash_attn_func(q, k, v, causal=True)  # (B, S, heads, head_dim)
    return (self.c_proj(out.reshape(B, S, -1)),)

# Monkey patch the first layer's attention (bind as a method; demo only)
model.transformer.h[0].attn.forward = forward_with_flash.__get__(model.transformer.h[0].attn)
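
A minimal training-step sketch for the patched model; the dataloader is a placeholder for any dataset that yields batches of token IDs:

import torch

model.cuda().train()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4  # only the LoRA adapters
)

for batch in dataloader:  # placeholder: yields {"input_ids": LongTensor of shape (B, S)}
    input_ids = batch["input_ids"].cuda()
    loss = model(input_ids=input_ids, labels=input_ids).loss  # causal LM loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()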

7. Compute Scaling: FLOPs

def transformer_flops(batch, seq_len, d_model, layers, vocab):
    """Rough forward-pass estimate, counting a multiply-add as 2 FLOPs."""
    # Per layer: attention scores/values, QKVO projections, FFN (d -> 4d -> d)
    attn_scores = 4 * batch * seq_len**2 * d_model
    attn_proj = 8 * batch * seq_len * d_model**2
    ffn = 16 * batch * seq_len * d_model**2
    flops = layers * (attn_scores + attn_proj + ffn)

    # Output logits
    flops += 2 * batch * seq_len * d_model * vocab
    return flops

fwd = transformer_flops(1, 2048, 12288, 96, 50257)
print(f"GPT-3 forward pass (2048 tokens): {fwd:.2e} FLOPs")  # ≈ 7e14, matching the 2 * params * tokens rule
# Total training compute ≈ 6 * params * tokens ≈ 6 * 175e9 * 300e9 ≈ 3.1e23 FLOPs

8. Parallelism in Code

# Tensor parallel: column-parallel linear (simplified), each rank holds a slice of the output features
import torch
import torch.nn as nn
import torch.distributed as dist

class TensorParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features // world_size, in_features))
        self.world_size = world_size

    def forward(self, x):
        out = x @ self.weight.t()  # local slice of the output features
        # All-gather the slices from every rank and concatenate along the feature dim
        # (a real implementation, e.g. Megatron-LM, uses an autograd-aware gather)
        gathered = [torch.empty_like(out) for _ in range(self.world_size)]
        dist.all_gather(gathered, out)
        return torch.cat(gathered, dim=-1)

9. Optimization: AdamW + Gradient Checkpointing

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

# Gradient checkpointing: recompute activations in the backward pass instead of storing them
model.gradient_checkpointing_enable()  # Hugging Face models
# For a plain nn.Sequential, call torch.utils.checkpoint.checkpoint_sequential inside forward
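
A self-contained sketch of checkpoint_sequential on a toy stack (layer sizes and segment count are arbitrary):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Sequential(nn.Linear(1024, 1024), nn.GELU()) for _ in range(8)]).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

# Only activations at segment boundaries are kept; the rest are recomputed during backward
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()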

10. Summary Table

Technique               Speed               Memory                           Params Trained
FlashAttention          2–4x                −80% activation memory           100%
LoRA                    ~1x                 −99% optimizer/gradient memory   0.1%
Tensor Parallel         up to 8x (8 GPUs)   weights split across GPUs        100%
Gradient Checkpointing  ~0.7x               −70% activation memory           100%

11. Practice Exercises

  1. Train LoRA on TinyShakespeare
  2. Benchmark FlashAttention vs standard attention (see the sketch after this list)
  3. Plot Chinchilla curve
  4. Implement pipeline parallelism
  5. Combine LoRA + FlashAttention
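
For exercise 2, a hedged benchmark skeleton: it uses PyTorch's fused scaled_dot_product_attention as the FlashAttention-style path and an explicit softmax as the baseline (requires a CUDA GPU; timings vary by hardware):

import time
import torch
import torch.nn.functional as F

def standard_attention(q, k, v):
    # Materializes the full N x N score matrix
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def bench(fn, *args, iters=20):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

print(f"standard attention: {bench(standard_attention, q, k, v) * 1e3:.2f} ms")
print(f"fused attention   : {bench(F.scaled_dot_product_attention, q, k, v) * 1e3:.2f} ms")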

12. Key Takeaways

✓ Chinchilla: at matched compute, a smaller model on more data wins (70B > 175B)
✓ Standard attention is O(N²) in time and memory
✓ FlashAttention brings attention memory down to O(N)
✓ LoRA trains ~0.1% of parameters
✓ Scale model size, data, and compute together

Full Copy-Paste: LoRA + FlashAttention

!pip install flash-attn peft transformers

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from flash_attn import flash_attn_func

# Load model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# LoRA: freeze the base weights, train low-rank adapters on the attention projections
config = LoraConfig(r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"])
model = get_peft_model(model, config)

# Replace attention with a simplified FlashAttention forward
# (the real GPT2Attention.forward takes extra arguments that this demo ignores)
def flash_forward(self, hidden_states, **kwargs):
    B, S, _ = hidden_states.shape
    q, k, v = self.c_attn(hidden_states).chunk(3, dim=-1)
    q, k, v = (t.view(B, S, self.num_heads, self.head_dim) for t in (q, k, v))
    out = flash_attn_func(q, k, v, causal=True).reshape(B, S, -1)
    return (self.c_proj(out),)

# Patch the first layer (demo only; patch every layer for real use)
model.transformer.h[0].attn.forward = flash_forward.__get__(model.transformer.h[0].attn)

print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Final Words

You now have the scaling toolkit used at DeepMind, Meta, and Google:
- Chinchilla-optimal data-to-model sizing
- FlashAttention-fast attention
- LoRA-efficient fine-tuning
- Scalable toward 1T parameters with tensor + pipeline parallelism


End of Module
You scale like the pros — efficient, fast, optimal.
Next: Build a 7B model.

Last updated: Nov 13, 2025
