Por Que Modelos Superdimensionados Não Sempre Sobreajustam?

O deep learning moderno vive de modelos com muito mais parâmetros que amostras de treino — um regime que a teoria clássica de ML diz que deveria levar a sobreajuste catastrófico. Mas modelos como ResNet-18 (11,2 milhões de parâmetros) generalizam muito bem em datasets como Fashion-MNIST (60 mil amostras). Esse paradoxo é explicado pela curva de dupla descida: conforme o tamanho do modelo aumenta, o erro de teste primeiro sobe (limiar de interpolação), depois cai novamente no regime superdimensionado.

O segredo? Nem todos os mínimos são iguais. Dois mínimos com valores de perda idênticos podem ter propriedades de generalização muito diferentes dependendo de sua geometria local. Mínimos planos (vales largos) correlacionam-se com melhor generalização que mínimos agudos (vales estreitos). O otimizador SAM foi projetado para buscar explicitamente mínimos planos, tornando-se uma ferramenta poderosa para melhorar a generalização.

Para uma perspectiva mais ampla sobre como agentes de IA estão redefinindo interações em plataformas, confira esta análise do Agent Lee da Cloudflare.

Python code editor showing SAM optimizer implementation with PyTorch IT Technology Image

Algoritmo SAM: Passo a Passo com Código PyTorch

O SAM modifica o loop de otimização padrão adicionando uma etapa de perturbação adversarial antes de cada atualização de peso. Aqui está o algoritmo central:

  1. Calcule o gradiente nos pesos atuais w
  2. Encontre a perturbação ε que maximiza a perda dentro de uma bola de raio ρ
  3. Calcule o gradiente no ponto perturbado w + ε
  4. Atualize os pesos usando esse gradiente

Classe do Otimizador SAM

import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    """
    Otimizador Sharpness-Aware Minimization.
    Args:
        params: parâmetros do modelo
        base_optimizer: otimizador base (ex: SGD, Adam)
        rho: raio de perturbação (hiperparâmetro)
    """
    def __init__(self, params, base_optimizer, rho=0.05):
        defaults = dict(rho=rho)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Calcula a perturbação adversarial para cada grupo de parâmetros
        for group in self.param_groups:
            scale = group['rho'] / (
                sum(p.grad.norm().item() ** 2 for p in group['params'] if p.grad is not None) ** 0.5 + 1e-12
            )
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale  # direção da perturbação
                p.add_(e_w)  # perturba os pesos
                self.state[p]['e_w'] = e_w  # armazena para o segundo passo
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Restaura os pesos originais e aplica a atualização do otimizador base
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])  # restaura pesos originais
        self.base_optimizer.step()  # atualização do otimizador base
        if zero_grad:
            self.zero_grad()

Loop de Treino com Correção para BatchNorm

Camadas BatchNorm atualizam estatísticas correntes durante forward passes. Como o SAM usa dois forward passes por iteração, devemos desabilitar as estatísticas correntes durante o segundo passe para evitar corrompê-las com pesos perturbados.

import torch.nn as nn

def disable_bn_stats(model):
    """Desabilita atualização de estatísticas correntes do BatchNorm durante forward."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = False

def enable_bn_stats(model):
    """Reabilita atualização de estatísticas correntes do BatchNorm após passe de perturbação."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = True

# Exemplo de loop de treino
model = PreActResNet18()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
sam = SAM(model.parameters(), base_optimizer, rho=0.05)

for batch_idx, (data, target) in enumerate(train_loader):
    # Primeiro forward-backward
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()
    
    # Primeiro passo: perturba os pesos
    sam.first_step(zero_grad=True)
    
    # Segundo forward-backward (com estatísticas BatchNorm desabilitadas)
    disable_bn_stats(model)
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()
    enable_bn_stats(model)
    
    # Segundo passo: restaura pesos e atualiza
    sam.second_step(zero_grad=True)

Data analyst visualizing double descent curve for model generalization Programming Illustration

Resultados Experimentais: SAM vs. SGD no Fashion-MNIST

Treinamos um PreAct ResNet-18 (11,2M parâmetros) no Fashion-MNIST (60K imagens de treino, 10 classes). Para comparação justa, o SAM roda 150 épocas enquanto o SGD roda 300 (já que cada passo SAM requer duas retropropagações).

MétricaSGD (300 épocas)SAM (150 épocas)SAM (200 épocas)
Acurácia no Teste92,0%92,5%92,7%
Gap de Generalização6,8%2,3%~3,0%

Observações principais:

  • SAM treina mais devagar: Acurácia de treino atinge níveis quase perfeitos mais tarde que SGD
  • SAM generaliza melhor: Vantagem na acurácia de teste é pequena (0,5-0,7%), mas o gap de generalização é dramaticamente menor (2,3% vs 6,8%)
  • Mínimos planos confirmados: O gap muito menor indica diretamente que SAM encontra mínimos mais planos

Limitações e Cuidados

  1. Custo computacional: SAM requer 2 forward e 2 backward por passo, dobrando o tempo de treino por época comparado a otimizadores padrão.
  2. Sensibilidade a hiperparâmetros: O parâmetro ρ (rho) controla o raio de perturbação e precisa ser ajustado para cada dataset/arquitetura.
  3. Complexidade do BatchNorm: A correção das estatísticas correntes adiciona código boilerplate, e esquecê-la pode degradar silenciosamente o desempenho.
  4. Datasets pequenos brilham: Os benefícios do SAM são mais pronunciados ao fine-tunar modelos pré-treinados em datasets pequenos; em treino em larga escala (ex: ImageNet do zero), os ganhos podem ser marginais.

Próximos Passos

  • Explore análise de Hessiana: Compare o espectro da Hessiana de modelos treinados com SAM vs. SGD para medir quantitativamente a planura.
  • Teste SAM com Adam: O otimizador base pode ser trocado — SAM(params, Adam(model.parameters(), lr=1e-3)) — frequentemente produz resultados state-of-the-art em fine-tuning de NLP.
  • Leia o artigo original: Foret et al. (2020) "Sharpness-Aware Minimization for Efficiently Improving Generalization" fornece fundamentação teórica e experimentos extensos.

Para um guia prático sobre construção de datasets culturalmente conscientes, veja este tutorial sobre Nemotron-Personas-Brazil.

Cloud infrastructure diagram with AI model training pipeline Coding Session Visual

Conclusão

SAM não é uma bala de prata, mas é um otimizador baseado em princípios que aborda diretamente o gap de generalização em modelos superdimensionados. A ideia central — minimizar a perda enquanto simultaneamente busca mínimos planos — é elegante e eficaz. Para profissionais que trabalham com datasets pequenos, fine-tuning de modelos pré-treinados ou lidam com sobreajuste, SAM é uma adição valiosa ao seu kit de ferramentas.

Mensagem principal: Se seu modelo é superdimensionado (a maioria dos modelos modernos de deep learning é) e você luta com generalização, experimente SAM. O custo computacional 2x geralmente vale a pena pela melhoria no desempenho de teste e redução do sobreajuste.

Referência: Este artigo é baseado na revisão pedagógica publicada no Towards Data Science.

Este conteúdo foi elaborado com o auxílio de ferramentas de IA, com base em fontes confiáveis, e revisado pela nossa equipe editorial antes da publicação. Não substitui o aconselhamento de um profissional especializado.