Why Overparameterized Models Don't Always Overfit

Modern deep learning thrives on models with far more parameters than training samples, a regime that classical ML theory says should lead to catastrophic overfitting. Yet models like ResNet-18 (11.2M parameters) generalize remarkably well on datasets like Fashion-MNIST (60K training samples). The double descent curve captures this paradox: as model size increases, test error first follows the classical U-shape, peaks near the interpolation threshold (the size at which the model can just fit the training data perfectly), then drops again in the overparameterized regime.

The key insight? Not all minima are equal. Two minima with identical loss values can have vastly different generalization properties depending on their local geometry. Flat minima (wide valleys) have been empirically associated with better generalization than sharp minima (narrow valleys). The Sharpness-Aware Minimization (SAM) optimizer is designed to explicitly seek flat minima, making it a powerful tool for improving model generalization.
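To make "local geometry" concrete, the sketch below compares two hypothetical one-dimensional losses that share the same minimum value but differ in curvature. It measures sharpness as the worst-case loss increase within a radius ρ of the minimum (the quantity SAM's inner maximization targets), found here by brute-force grid search rather than SAM's gradient-based approximation. The loss functions and the `sharpness` helper are illustrative assumptions, not part of any library.

```python
def sharpness(loss_fn, w, rho, steps=1000):
    """Worst-case loss increase within [w - rho, w + rho] for a 1D loss.

    Brute-force grid search stands in for SAM's inner maximization.
    """
    base = loss_fn(w)
    worst = max(loss_fn(w - rho + 2 * rho * i / steps) for i in range(steps + 1))
    return worst - base


def sharp_loss(w):
    """Hypothetical narrow valley: minimum value 0 at w = 0, high curvature."""
    return 50.0 * w ** 2


def flat_loss(w):
    """Hypothetical wide valley: same minimum value 0 at w = 0, low curvature."""
    return 0.5 * w ** 2


rho = 0.1
print("sharp minimum:", sharpness(sharp_loss, 0.0, rho))
print("flat minimum: ", sharpness(flat_loss, 0.0, rho))
```

With ρ = 0.1, the narrow valley's loss rises by about 0.5 while the wide valley's rises by only about 0.005, even though both minima have zero loss.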



SAM Algorithm: Step-by-Step with PyTorch Code

SAM modifies the standard optimization loop by adding an adversarial perturbation step before each weight update. Here's the core algorithm:

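  1. Compute the gradient ∇L(w) at the current weights w
  2. Approximate the perturbation ε that maximizes the loss within an L2 ball of radius ρ; SAM uses the first-order solution ε = ρ · ∇L(w) / ‖∇L(w)‖
  3. Compute the gradient at the perturbed point w + ε
  4. Update the original weights w (not w + ε) with the base optimizer, using the gradient from step 3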
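SAM's objective is to minimize, over w, the maximum of L(w + ε) over all ‖ε‖ ≤ ρ. The four steps above can be traced on a toy problem before looking at the PyTorch class. This sketch uses plain Python, a hypothetical two-parameter quadratic loss with one steep and one shallow direction, and plain gradient descent as the base optimizer; `loss_fn`, `grad_fn`, and `sam_step` are illustrative names.

```python
import math


def loss_fn(w):
    """Hypothetical toy loss L(w) = 0.5 * (w0**2 + 10 * w1**2)."""
    return 0.5 * (w[0] ** 2 + 10.0 * w[1] ** 2)


def grad_fn(w):
    """Gradient of the toy loss."""
    return [w[0], 10.0 * w[1]]


def sam_step(w, grad_fn, lr=0.01, rho=0.05):
    """One SAM update with plain gradient descent as the base optimizer."""
    # 1. Gradient at the current weights w
    g = grad_fn(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    # 2. First-order worst-case perturbation: eps = rho * g / ||g||
    eps = [rho * gi / g_norm for gi in g]
    # 3. Gradient at the perturbed point w + eps
    g_sam = grad_fn([wi + ei for wi, ei in zip(w, eps)])
    # 4. Update the ORIGINAL weights w (not w + eps) with that gradient
    return [wi - lr * gi for wi, gi in zip(w, g_sam)]


w = [1.0, 1.0]
w_next = sam_step(w, grad_fn)
print(w_next, loss_fn(w), loss_fn(w_next))
```

Starting from w = (1, 1), the step lands at roughly (0.990, 0.895) and the loss falls from 5.5 to about 4.5. Because ε always has norm ρ, the gradient in step 3 is taken slightly uphill along the steep direction, which is what pushes SAM away from narrow valleys.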

SAM Optimizer Class

import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    """
    Sharpness-Aware Minimization optimizer.
    Args:
        params: model parameters
        base_optimizer: an already-constructed optimizer over the same
            parameters (e.g., torch.optim.SGD, torch.optim.Adam)
        rho: perturbation radius (hyperparameter)
    """
    def __init__(self, params, base_optimizer, rho=0.05):
        defaults = dict(rho=rho)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer
        # Share parameter groups with the base optimizer so both act on the
        # same tensors. The base optimizer's groups have no 'rho' key, so add it
        # (otherwise group['rho'] below raises a KeyError).
        self.param_groups = self.base_optimizer.param_groups
        for group in self.param_groups:
            group.setdefault('rho', rho)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Global L2 norm of the gradient across ALL parameter groups
        grad_norm = sum(
            p.grad.norm().item() ** 2
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ) ** 0.5
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale  # ascent direction; total norm is rho
                p.add_(e_w)  # move to w + e(w)
                self.state[p]['e_w'] = e_w  # store for second step
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Restore original weights and apply base optimizer update
        for group in self.param_groups:
            for p in group['params']:
                # Restore every perturbed parameter, even if the second
                # backward pass left its gradient as None
                if 'e_w' not in self.state[p]:
                    continue
                p.sub_(self.state[p].pop('e_w'))  # back to original weights w
        self.base_optimizer.step()  # update w with the gradient taken at w + e(w)
        if zero_grad:
            self.zero_grad()

Training Loop with BatchNorm Fix

BatchNorm layers update their running statistics on every training-mode forward pass. Since SAM uses two forward passes per iteration, we must stop these updates during the second pass; otherwise the statistics are contaminated by activations computed at the perturbed weights.

import torch
import torch.nn as nn

# BatchNorm variants whose running statistics are frozen on the second pass
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def disable_bn_stats(model):
    """Stop BatchNorm layers from updating running stats (perturbed pass)."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.track_running_stats = False

def enable_bn_stats(model):
    """Re-enable BatchNorm running stats update after perturbation pass."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.track_running_stats = True

# Training loop example (one epoch).
# PreActResNet18 and train_loader are assumed to be defined elsewhere.
model = PreActResNet18()
criterion = nn.CrossEntropyLoss()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
sam = SAM(model.parameters(), base_optimizer, rho=0.05)

model.train()
for batch_idx, (data, target) in enumerate(train_loader):
    # First forward-backward pass at w (BatchNorm stats update normally)
    loss = criterion(model(data), target)
    loss.backward()

    # First step: perturb weights to w + e(w)
    sam.first_step(zero_grad=True)

    # Second forward-backward pass at w + e(w), BatchNorm stats frozen
    disable_bn_stats(model)
    loss = criterion(model(data), target)
    loss.backward()
    enable_bn_stats(model)

    # Second step: restore weights and update
    sam.second_step(zero_grad=True)
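The effect the BatchNorm fix guards against can be seen with plain arithmetic. PyTorch's BatchNorm updates its running mean as running = (1 - momentum) · running + momentum · batch_mean, with a default momentum of 0.1. This sketch applies that rule once (stats updated only on the clean pass, as in the loop above) and twice (stats also updated on the perturbed pass); the batch means are made-up numbers chosen only for illustration.

```python
def ema_update(running, batch_stat, momentum=0.1):
    """BatchNorm-style running-statistic update (PyTorch's default momentum is 0.1)."""
    return (1 - momentum) * running + momentum * batch_stat


running_mean = 0.0
clean_batch_mean = 1.0      # hypothetical activation mean at weights w
perturbed_batch_mean = 3.0  # hypothetical activation mean at weights w + e(w)

# With the fix: only the clean first pass updates the running mean
fixed = ema_update(running_mean, clean_batch_mean)

# Without the fix: the perturbed second pass updates it again
unfixed = ema_update(ema_update(running_mean, clean_batch_mean), perturbed_batch_mean)

print(fixed, unfixed)
```

Starting from a running mean of 0, the fixed loop ends at 0.1, while the unfixed loop ends near 0.39 because the second update is pulled toward activations that the final (restored) weights will never produce.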


Experimental Results: SAM vs. SGD on Fashion-MNIST

We trained a PreAct ResNet-18 (11.2M parameters) on Fashion-MNIST (60K training images, 10 classes). To keep the compute budget comparable, SAM runs for 150 epochs and SGD for 300, since each SAM step requires two forward-backward passes.

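| Metric             | SGD (300 epochs) | SAM (150 epochs) | SAM (200 epochs) |
|--------------------|------------------|------------------|------------------|
| Test accuracy      | 92.0%            | 92.5%            | 92.7%            |
| Generalization gap | 6.8%             | 2.3%             | ~3.0%            |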
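The compute matching in this experiment is simple arithmetic: SAM performs two forward-backward passes per step, so 150 SAM epochs cost about as much gradient computation as 300 SGD epochs. This sketch tallies that budget and, assuming the generalization gap is training accuracy minus test accuracy (an assumption; the table does not define it), recovers the implied training accuracy. The helper names are illustrative.

```python
def passes_per_sample(epochs, passes_per_step):
    """Forward-backward passes each training sample sees over a full run."""
    return epochs * passes_per_step


def implied_train_accuracy(test_acc, gap):
    """Assumes gap = train accuracy - test accuracy (percentage points)."""
    return test_acc + gap


budgets = {
    "SGD (300 epochs)": passes_per_sample(300, passes_per_step=1),
    "SAM (150 epochs)": passes_per_sample(150, passes_per_step=2),
    "SAM (200 epochs)": passes_per_sample(200, passes_per_step=2),
}
for name, passes in budgets.items():
    print(f"{name}: {passes} forward-backward passes per sample")

# Figures from the results table; the gap definition is an assumption.
print("SGD implied train acc:", implied_train_accuracy(92.0, 6.8))
print("SAM (150) implied train acc:", implied_train_accuracy(92.5, 2.3))
```

The tally also shows that the 200-epoch SAM run uses 400 pass-equivalents per sample, about a third more compute than the 300-epoch SGD baseline, so that column is not a compute-matched comparison. Under the assumed gap definition, the implied training accuracies (98.8% for SGD, 94.8% for SAM at 150 epochs) fit the observation that SAM trains more slowly.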

Key observations:

  • SAM trains slower: Training accuracy reaches near-perfect levels later than SGD
  • SAM generalizes better: The test accuracy edge is small (0.5 to 0.7 percentage points), but the generalization gap is dramatically smaller (2.3% vs 6.8%)
  • Consistent with flat minima: The much lower generalization gap is consistent with SAM finding flatter minima, although the gap alone does not measure flatness directly

Limitations and Caveats

  1. Computational cost: SAM requires 2 forward and 2 backward passes per step, effectively doubling training time per epoch compared to standard optimizers.
  2. Hyperparameter sensitivity: The ρ (rho) parameter controls the perturbation radius and needs tuning for each dataset/architecture.
  3. BatchNorm complexity: The running statistics fix adds boilerplate code, and forgetting it can silently degrade performance.
  4. Gains vary by setting: SAM's benefits tend to be most pronounced when fine-tuning pre-trained models on small datasets; on large-scale training (e.g., ImageNet from scratch), gains may be marginal.

Next Steps

  • Explore Hessian analysis: Compare the Hessian spectrum of SAM-trained vs. SGD-trained models to quantitatively measure flatness.
  • Try SAM with Adam: The base optimizer can be swapped, e.g. SAM(model.parameters(), torch.optim.Adam(model.parameters(), lr=1e-3)), which is worth trying for NLP fine-tuning.
  • Read the original paper: Foret et al. (2020) "Sharpness-Aware Minimization for Efficiently Improving Generalization" provides theoretical grounding and extensive experiments.
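One way to follow up on the Hessian suggestion is to estimate the largest Hessian eigenvalue, a common sharpness proxy, by power iteration; each iteration needs only a Hessian-vector product, never the full Hessian. This sketch approximates those products with central finite differences of a gradient function and checks itself on a hypothetical quadratic whose Hessian is diag(1, 10). Power iteration returns the largest-magnitude eigenvalue, which near a minimum is the top one. In PyTorch, exact Hessian-vector products are usually obtained by double backpropagation with `torch.autograd.grad(..., create_graph=True)`; all function names below are illustrative.

```python
import math


def hvp(grad_fn, w, v, h=1e-4):
    """Central-difference Hessian-vector product: (g(w + h*v) - g(w - h*v)) / (2h)."""
    g_plus = grad_fn([wi + h * vi for wi, vi in zip(w, v)])
    g_minus = grad_fn([wi - h * vi for wi, vi in zip(w, v)])
    return [(gp - gm) / (2 * h) for gp, gm in zip(g_plus, g_minus)]


def top_hessian_eigenvalue(grad_fn, w, iters=100):
    """Power iteration: estimate of the largest-magnitude Hessian eigenvalue at w."""
    v = [1.0 / math.sqrt(len(w))] * len(w)  # unit-norm starting vector
    eig = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, w, v)
        eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient v^T H v
        norm = math.sqrt(sum(x * x for x in hv)) + 1e-12
        v = [x / norm for x in hv]
    return eig


def grad_fn(w):
    """Gradient of the hypothetical loss L(w) = 0.5 * (w0**2 + 10 * w1**2)."""
    return [w[0], 10.0 * w[1]]


print(top_hessian_eigenvalue(grad_fn, [0.0, 0.0]))
```

Comparing this estimate at the end of SGD and SAM training (with exact Hessian-vector products on the real model) would give a quantitative flatness comparison rather than relying on the generalization gap alone.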



Conclusion

SAM is not a silver bullet, but it is a principled optimizer that directly targets the generalization gap in overparameterized models. The core idea, minimizing the worst-case loss in a small neighborhood of the weights so that training is steered toward flat minima, is both elegant and effective. For practitioners working with small datasets, fine-tuning pre-trained models, or dealing with overfitting, SAM is a valuable addition to the toolkit.

Key takeaway: If your model is overparameterized (most modern deep learning models are) and you struggle with generalization, try SAM. The 2x computational overhead is often worth the improved test performance and reduced overfitting.

Reference: This article is based on the pedagogical review published on Towards Data Science.

This content was drafted using AI tools based on reliable sources, and has been reviewed by our editorial team before publication. It is not intended to replace professional advice.