Por Que Modelos Superdimensionados Não Sempre Sobreajustam?
O deep learning moderno vive de modelos com muito mais parâmetros que amostras de treino — um regime que a teoria clássica de ML diz que deveria levar a sobreajuste catastrófico. Mas modelos como ResNet-18 (11,2 milhões de parâmetros) generalizam muito bem em datasets como Fashion-MNIST (60 mil amostras). Esse paradoxo é explicado pela curva de dupla descida: conforme o tamanho do modelo aumenta, o erro de teste primeiro sobe (limiar de interpolação), depois cai novamente no regime superdimensionado.
O segredo? Nem todos os mínimos são iguais. Dois mínimos com valores de perda idênticos podem ter propriedades de generalização muito diferentes dependendo de sua geometria local. Mínimos planos (vales largos) correlacionam-se com melhor generalização que mínimos agudos (vales estreitos). O otimizador SAM foi projetado para buscar explicitamente mínimos planos, tornando-se uma ferramenta poderosa para melhorar a generalização.
Para uma perspectiva mais ampla sobre como agentes de IA estão redefinindo interações em plataformas, confira esta análise do Agent Lee da Cloudflare.

Algoritmo SAM: Passo a Passo com Código PyTorch
O SAM modifica o loop de otimização padrão adicionando uma etapa de perturbação adversarial antes de cada atualização de peso. Aqui está o algoritmo central:
- Calcule o gradiente nos pesos atuais
w - Encontre a perturbação
εque maximiza a perda dentro de uma bola de raioρ - Calcule o gradiente no ponto perturbado
w + ε - Atualize os pesos usando esse gradiente
Classe do Otimizador SAM
import torch
from torch.optim import Optimizer
class SAM(Optimizer):
"""
Otimizador Sharpness-Aware Minimization.
Args:
params: parâmetros do modelo
base_optimizer: otimizador base (ex: SGD, Adam)
rho: raio de perturbação (hiperparâmetro)
"""
def __init__(self, params, base_optimizer, rho=0.05):
defaults = dict(rho=rho)
super(SAM, self).__init__(params, defaults)
self.base_optimizer = base_optimizer
self.param_groups = self.base_optimizer.param_groups
@torch.no_grad()
def first_step(self, zero_grad=False):
# Calcula a perturbação adversarial para cada grupo de parâmetros
for group in self.param_groups:
scale = group['rho'] / (
sum(p.grad.norm().item() ** 2 for p in group['params'] if p.grad is not None) ** 0.5 + 1e-12
)
for p in group['params']:
if p.grad is None:
continue
e_w = p.grad * scale # direção da perturbação
p.add_(e_w) # perturba os pesos
self.state[p]['e_w'] = e_w # armazena para o segundo passo
if zero_grad:
self.zero_grad()
@torch.no_grad()
def second_step(self, zero_grad=False):
# Restaura os pesos originais e aplica a atualização do otimizador base
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.sub_(self.state[p]['e_w']) # restaura pesos originais
self.base_optimizer.step() # atualização do otimizador base
if zero_grad:
self.zero_grad()
Loop de Treino com Correção para BatchNorm
Camadas BatchNorm atualizam estatísticas correntes durante forward passes. Como o SAM usa dois forward passes por iteração, devemos desabilitar as estatísticas correntes durante o segundo passe para evitar corrompê-las com pesos perturbados.
import torch.nn as nn
def disable_bn_stats(model):
"""Desabilita atualização de estatísticas correntes do BatchNorm durante forward."""
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
module.track_running_stats = False
def enable_bn_stats(model):
"""Reabilita atualização de estatísticas correntes do BatchNorm após passe de perturbação."""
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
module.track_running_stats = True
# Exemplo de loop de treino
model = PreActResNet18()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
sam = SAM(model.parameters(), base_optimizer, rho=0.05)
for batch_idx, (data, target) in enumerate(train_loader):
# Primeiro forward-backward
output = model(data)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
# Primeiro passo: perturba os pesos
sam.first_step(zero_grad=True)
# Segundo forward-backward (com estatísticas BatchNorm desabilitadas)
disable_bn_stats(model)
output = model(data)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
enable_bn_stats(model)
# Segundo passo: restaura pesos e atualiza
sam.second_step(zero_grad=True)

Resultados Experimentais: SAM vs. SGD no Fashion-MNIST
Treinamos um PreAct ResNet-18 (11,2M parâmetros) no Fashion-MNIST (60K imagens de treino, 10 classes). Para comparação justa, o SAM roda 150 épocas enquanto o SGD roda 300 (já que cada passo SAM requer duas retropropagações).
| Métrica | SGD (300 épocas) | SAM (150 épocas) | SAM (200 épocas) |
|---|---|---|---|
| Acurácia no Teste | 92,0% | 92,5% | 92,7% |
| Gap de Generalização | 6,8% | 2,3% | ~3,0% |
Observações principais:
- SAM treina mais devagar: Acurácia de treino atinge níveis quase perfeitos mais tarde que SGD
- SAM generaliza melhor: Vantagem na acurácia de teste é pequena (0,5-0,7%), mas o gap de generalização é dramaticamente menor (2,3% vs 6,8%)
- Mínimos planos confirmados: O gap muito menor indica diretamente que SAM encontra mínimos mais planos
Limitações e Cuidados
- Custo computacional: SAM requer 2 forward e 2 backward por passo, dobrando o tempo de treino por época comparado a otimizadores padrão.
- Sensibilidade a hiperparâmetros: O parâmetro
ρ(rho) controla o raio de perturbação e precisa ser ajustado para cada dataset/arquitetura. - Complexidade do BatchNorm: A correção das estatísticas correntes adiciona código boilerplate, e esquecê-la pode degradar silenciosamente o desempenho.
- Datasets pequenos brilham: Os benefícios do SAM são mais pronunciados ao fine-tunar modelos pré-treinados em datasets pequenos; em treino em larga escala (ex: ImageNet do zero), os ganhos podem ser marginais.
Próximos Passos
- Explore análise de Hessiana: Compare o espectro da Hessiana de modelos treinados com SAM vs. SGD para medir quantitativamente a planura.
- Teste SAM com Adam: O otimizador base pode ser trocado —
SAM(params, Adam(model.parameters(), lr=1e-3))— frequentemente produz resultados state-of-the-art em fine-tuning de NLP. - Leia o artigo original: Foret et al. (2020) "Sharpness-Aware Minimization for Efficiently Improving Generalization" fornece fundamentação teórica e experimentos extensos.
Para um guia prático sobre construção de datasets culturalmente conscientes, veja este tutorial sobre Nemotron-Personas-Brazil.

Conclusão
SAM não é uma bala de prata, mas é um otimizador baseado em princípios que aborda diretamente o gap de generalização em modelos superdimensionados. A ideia central — minimizar a perda enquanto simultaneamente busca mínimos planos — é elegante e eficaz. Para profissionais que trabalham com datasets pequenos, fine-tuning de modelos pré-treinados ou lidam com sobreajuste, SAM é uma adição valiosa ao seu kit de ferramentas.
Mensagem principal: Se seu modelo é superdimensionado (a maioria dos modelos modernos de deep learning é) e você luta com generalização, experimente SAM. O custo computacional 2x geralmente vale a pena pela melhoria no desempenho de teste e redução do sobreajuste.
Referência: Este artigo é baseado na revisão pedagógica publicada no Towards Data Science.