¿Por Qué los Modelos Sobredimensionados No Siempre Sobreajustan?

El deep learning moderno vive de modelos con muchos más parámetros que muestras de entrenamiento — un régimen que la teoría clásica de ML dice que debería llevar a sobreajuste catastrófico. Pero modelos como ResNet-18 (11.2 millones de parámetros) generalizan notablemente bien en datasets como Fashion-MNIST (60 mil muestras). Esta paradoja se explica por la curva de doble descenso: a medida que el tamaño del modelo aumenta, el error de prueba primero sube (umbral de interpolación), luego vuelve a caer en el régimen sobredimensionado.

El truco: no todos los mínimos son iguales. Dos mínimos con valores de pérdida idénticos pueden tener propiedades de generalización muy diferentes según su geometría local. Mínimos planos (valles anchos) se correlacionan con mejor generalización que mínimos agudos (valles estrechos). El optimizador SAM está diseñado para buscar explícitamente mínimos planos, convirtiéndose en una herramienta poderosa para mejorar la generalización.

Para una perspectiva más amplia sobre cómo los agentes de IA están redefiniendo las interacciones en plataformas, checa este análisis del Agent Lee de Cloudflare.

Python code editor showing SAM optimizer implementation with PyTorch Technical Structure Concept

Algoritmo SAM: Paso a Paso con Código PyTorch

SAM modifica el bucle de optimización estándar agregando un paso de perturbación adversarial antes de cada actualización de peso. Aquí está el algoritmo central:

  1. Calcula el gradiente en los pesos actuales w
  2. Encuentra la perturbación ε que maximiza la pérdida dentro de una bola de radio ρ
  3. Calcula el gradiente en el punto perturbado w + ε
  4. Actualiza los pesos usando ese gradiente

Clase del Optimizador SAM

import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    """
    Optimizador Sharpness-Aware Minimization.
    Args:
        params: parámetros del modelo
        base_optimizer: optimizador base (ej: SGD, Adam)
        rho: radio de perturbación (hiperparámetro)
    """
    def __init__(self, params, base_optimizer, rho=0.05):
        defaults = dict(rho=rho)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Calcula la perturbación adversarial para cada grupo de parámetros
        for group in self.param_groups:
            scale = group['rho'] / (
                sum(p.grad.norm().item() ** 2 for p in group['params'] if p.grad is not None) ** 0.5 + 1e-12
            )
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale  # dirección de la perturbación
                p.add_(e_w)  # perturba los pesos
                self.state[p]['e_w'] = e_w  # almacena para el segundo paso
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Restaura los pesos originales y aplica la actualización del optimizador base
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])  # restaura pesos originales
        self.base_optimizer.step()  # actualización del optimizador base
        if zero_grad:
            self.zero_grad()

Bucle de Entrenamiento con Corrección para BatchNorm

Las capas BatchNorm actualizan estadísticas corrientes durante forward passes. Como SAM usa dos forward passes por iteración, debemos deshabilitar las estadísticas corrientes durante el segundo pase para evitar corromperlas con pesos perturbados.

import torch.nn as nn

def disable_bn_stats(model):
    """Deshabilita la actualización de estadísticas corrientes de BatchNorm durante forward."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = False

def enable_bn_stats(model):
    """Rehabilita la actualización de estadísticas corrientes de BatchNorm después del pase de perturbación."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = True

# Ejemplo de bucle de entrenamiento
model = PreActResNet18()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
sam = SAM(model.parameters(), base_optimizer, rho=0.05)

for batch_idx, (data, target) in enumerate(train_loader):
    # Primer forward-backward
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()
    
    # Primer paso: perturba los pesos
    sam.first_step(zero_grad=True)
    
    # Segundo forward-backward (con estadísticas BatchNorm deshabilitadas)
    disable_bn_stats(model)
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()
    enable_bn_stats(model)
    
    # Segundo paso: restaura pesos y actualiza
    sam.second_step(zero_grad=True)

Data analyst visualizing double descent curve for model generalization IT Technology Image

Resultados Experimentales: SAM vs. SGD en Fashion-MNIST

Entrenamos un PreAct ResNet-18 (11.2M parámetros) en Fashion-MNIST (60K imágenes de entrenamiento, 10 clases). Para comparación justa, SAM corre 150 épocas mientras que SGD corre 300 (ya que cada paso SAM requiere dos retropropagaciones).

MétricaSGD (300 épocas)SAM (150 épocas)SAM (200 épocas)
Precisión en Prueba92.0%92.5%92.7%
Brecha de Generalización6.8%2.3%~3.0%

Observaciones clave:

  • SAM entrena más lento: La precisión de entrenamiento alcanza niveles casi perfectos más tarde que SGD
  • SAM generaliza mejor: La ventaja en precisión de prueba es pequeña (0.5-0.7%), pero la brecha de generalización es dramáticamente menor (2.3% vs 6.8%)
  • Mínimos planos confirmados: La brecha mucho menor indica directamente que SAM encuentra mínimos más planos

Limitaciones y Precauciones

  1. Costo computacional: SAM requiere 2 forward y 2 backward por paso, duplicando el tiempo de entrenamiento por época comparado con optimizadores estándar.
  2. Sensibilidad a hiperparámetros: El parámetro ρ (rho) controla el radio de perturbación y necesita ajustarse para cada dataset/arquitectura.
  3. Complejidad de BatchNorm: La corrección de estadísticas corrientes añade código boilerplate, y olvidarla puede degradar silenciosamente el rendimiento.
  4. Datasets pequeños brillan: Los beneficios de SAM son más pronunciados al hacer fine-tune de modelos preentrenados en datasets pequeños; en entrenamiento a gran escala (ej: ImageNet desde cero), las ganancias pueden ser marginales.

Próximos Pasos

  • Explora análisis de Hessiana: Compara el espectro de la Hessiana de modelos entrenados con SAM vs. SGD para medir cuantitativamente la planicie.
  • Prueba SAM con Adam: El optimizador base se puede cambiar — SAM(params, Adam(model.parameters(), lr=1e-3)) — a menudo produce resultados state-of-the-art en fine-tuning de NLP.
  • Lee el artículo original: Foret et al. (2020) "Sharpness-Aware Minimization for Efficiently Improving Generalization" proporciona fundamentación teórica y experimentos extensos.

Para una guía práctica sobre construcción de datasets culturalmente conscientes, mira este tutorial sobre Nemotron-Personas-Brazil.

Cloud infrastructure diagram with AI model training pipeline Programming Illustration

Conclusión

SAM no es una bala de plata, pero es un optimizador basado en principios que aborda directamente la brecha de generalización en modelos sobredimensionados. La idea central — minimizar la pérdida mientras simultáneamente se buscan mínimos planos — es elegante y efectiva. Para profesionales que trabajan con datasets pequeños, fine-tuning de modelos preentrenados o lidian con sobreajuste, SAM es una adición valiosa a tu kit de herramientas.

Mensaje clave: Si tu modelo está sobredimensionado (la mayoría de los modelos modernos de deep learning lo están) y luchas con la generalización, prueba SAM. El costo computacional 2x generalmente vale la pena por la mejora en el rendimiento de prueba y la reducción del sobreajuste.

Referencia: Este artículo está basado en la revisión pedagógica publicada en Towards Data Science.

Este contenido fue redactado con la asistencia de herramientas de IA, basándose en fuentes confiables, y fue revisado por nuestro equipo editorial antes de su publicación. No reemplaza el asesoramiento de un profesional especializado.