왜 일반화가 문제인가? 과대매개변수화(Overparameterization)의 역설

최근 Computer Vision과 NLP 분야의 폭발적인 성장은 과대매개변수화된 모델에 기반합니다. 모델이 학습 데이터를 완전히 암기할 수 있을 정도로 많은 파라미터를 가진다는 뜻이죠. 실제로 이런 모델은 학습 정확도가 99%에 육박하고, 학습 손실은 거의 0에 가깝습니다.

그런데 고전적인 머신러닝 이론은 이렇게 파라미터가 많으면 **과대적합(Overfitting)**이 일어나 일반화 성능이 나빠진다고 말합니다. 그런데 실제로는 정반대입니다. 2018년 Belkin et al.과 2019년 Nakkiran et al.의 연구는 이중 하강(Double Descent) 곡선을 발견했습니다. 모델 크기가 커질수록 일반화 성능이 먼저 나빠졌다가(첫 번째 하강) 특정 임계점을 넘으면 다시 좋아진다는(두 번째 하강) 현상입니다.

이 역설의 핵심은 **손실 함수의 지형(Loss Landscape)**에 있습니다. 과대매개변수화된 모델의 손실 지형에는 수많은 지역 최소점이 존재하는데, 이들은 비슷한 손실 값을 가지더라도 **날카로움(Sharpness)**에서 큰 차이를 보입니다. 날카로운 최소점(좁은 골짜기)은 작은 변화에도 손실이 급격히 변하지만, 평평한 최소점(넓은 골짜기)은 변화에 강건합니다. 연구들은 평평한 최소점이 더 나은 일반화 성능과 강한 상관관계가 있음을 보여줍니다.

이 문제를 정면으로 해결하는 옵티마이저가 바로 **Sharpness-Aware Minimization (SAM)**입니다. Foret et al.(2019)이 제안한 이 방법은 단순히 손실을 최소화하는 대신, 주변 이웃에서 가장 손실이 큰 방향(적대적 섭동)을 찾아 그 손실을 최소화함으로써 평평한 최소점으로 수렴하도록 유도합니다.

국내 SI/프로젝트 환경 팁: 대부분의 국내 프로젝트는 사전 학습된 모델을 Fine-tuning하는 방식이 많습니다. SAM은 특히 소규모 데이터셋으로 Fine-tuning할 때 매우 효과적입니다. 데이터가 부족한 환경에서 일반화 성능을 끌어올리는 강력한 도구가 될 수 있어요.

근거자료: Towards Data Science 원문

Deep learning overparameterization double descent curve sharp minima concept Technical Structure Concept

SAM 알고리즘의 수학적 직관과 PyTorch 구현

핵심 아이디어: 두 단계 최적화

SAM의 동작은 한 번의 반복(Iteration) 안에서 두 단계로 이루어집니다.

  1. 첫 번째 단계 (적대적 섭동 탐색): 현재 가중치 (w_t)에서 반경 (\rho) 이내의 이웃 중 손실을 최대화하는 방향 (\epsilon(w_t))을 찾습니다.

    [\epsilon(w_t) = \rho \frac{\nabla L(w_t)}{|\nabla L(w_t)|}]

  2. 두 번째 단계 (최적화): 찾은 적대적 섭동 지점 (w_t + \epsilon(w_t))에서 기울기를 계산하고, 그 기울기를 사용해 원래 가중치를 업데이트합니다.

    [w_{t+1} = w_t - \eta \nabla L(w_t + \epsilon(w_t))]

이 간단한 수식 하나가 평평한 최소점을 찾는 비결입니다.

PyTorch 구현: SAM 옵티마이저 클래스

import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    """
    Sharpness-Aware Minimization (SAM) 옵티마이저
    
    Args:
        params: 모델 파라미터
        base_optimizer: 기본 옵티마이저 (SGD, Adam 등)
        rho: 섭동 반경 (하이퍼파라미터, 일반적으로 0.05)
    """
    def __init__(self, params, base_optimizer, rho=0.05):
        # SAM은 기본 옵티마이저를 래핑
        self.base_optimizer = base_optimizer
        self.rho = rho
        super().__init__(params, defaults={'rho': rho})
        
    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # 1단계: 적대적 섭동 계산 후 가중치에 적용
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # 섭동 방향 = 정규화된 그래디언트
                grad_norm = p.grad.norm()
                if grad_norm == 0:
                    continue
                e_w = self.rho * p.grad / (grad_norm + 1e-12)
                # 가중치에 섭동 적용 (임시 저장)
                p.add_(e_w)
        if zero_grad:
            self.zero_grad()
            
    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # 2단계: 원래 가중치 복원 후 기본 옵티마이저로 업데이트
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # 섭동 제거 (원래 가중치로 복원)
                p.sub_(self.rho * p.grad / (p.grad.norm() + 1e-12))
        # 기본 옵티마이저 스텝 수행
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

학습 루프에서 SAM 사용하기

# SAM 옵티마이저 생성 (SGD를 기본으로)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.0, weight_decay=0.0)
sam_optimizer = SAM(model.parameters(), base_optimizer, rho=0.05)

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # --- 첫 번째 forward-backward ---
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        # 1단계: 적대적 섭동 적용
        sam_optimizer.first_step(zero_grad=True)
        
        # --- 두 번째 forward-backward (섭동된 가중치로) ---
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        # 2단계: 가중치 복원 및 업데이트
        sam_optimizer.second_step(zero_grad=True)

⚠️ BatchNorm 주의사항 (실무에서 가장 중요한 부분)

모델에 BatchNorm 레이어가 포함된 경우, 두 번째 forward pass에서 running statistics가 업데이트되지 않도록 반드시 비활성화해야 합니다. 그렇지 않으면 섭동된 가중치의 통계가 원래 모델의 running statistics를 오염시킵니다.

def disable_bn_stats(model):
    """
    BatchNorm의 running statistics 업데이트를 비활성화
    """
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
            module.track_running_stats = False

def enable_bn_stats(model):
    """
    BatchNorm의 running statistics 업데이트를 활성화 (기본값)
    """
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
            module.track_running_stats = True

수정된 학습 루프:

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # 첫 번째 forward-backward (정상 통계 업데이트)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        sam_optimizer.first_step(zero_grad=True)
        
        # 두 번째 forward-backward 전에 running statistics 비활성화
        disable_bn_stats(model)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        enable_bn_stats(model)  # 다음 iteration을 위해 다시 활성화
        
        sam_optimizer.second_step(zero_grad=True)

SAM optimizer training loop diagram with PyTorch code and BatchNorm handling Software Concept Art

실전 실험: Fashion-MNIST로 SAM 효과 검증

실험 설정

  • 데이터셋: Fashion-MNIST (60,000 학습 / 10,000 테스트, 10개 클래스, 28x28 흑백)
  • 모델: PreAct ResNet-18 (약 1,120만 파라미터, 파라미터-샘플 비율 186:1 → 과대매개변수화)
  • 비교: SGD (lr=0.05, momentum=0, weight_decay=0) vs SAM (rho=0.05, 동일 SGD base)
  • 공정 비교: SAM은 1 epoch에 2번의 backward를 수행하므로, SGD는 2 epoch를 1 standardized epoch로 간주
  • 실험 기간: 150 standardized epochs (SAM 150 epoch, SGD 300 epoch)

결과 요약

지표SGD (150 std epochs)SAM (150 std epochs)SAM (추가 50 epochs)
테스트 정확도92.0%92.5%92.7%
일반화 갭 (학습-테스트)6.8%2.3%약 3%

핵심 인사이트

  1. 테스트 정확도 차이는 작지만, 일반화 갭 차이는 크다. SAM이 6.8% → 2.3%로 일반화 갭을 3배 가까이 줄였습니다. 이는 모델이 학습 데이터에 덜 과적합되었음을 의미합니다.
  2. SAM은 느리게 학습하지만, 더 평평한 최소점에 도달합니다. 초반 80 epoch까지는 SGD보다 테스트 정확도가 낮지만, 이후 역전하며 더 나은 일반화를 보여줍니다.
  3. 추가 학습에도 일반화 갭이 SGD보다 훨씬 낮게 유지됩니다.

주의사항: SAM은 모든 상황에서 SGD/Adam보다 항상 좋은 것은 아닙니다. 특히 데이터가 충분히 많거나 모델이 작은 경우 효과가 미미할 수 있습니다. 또한 2배의 연산 비용(forward-backward 2회)이 필요하므로, 학습 시간이 중요한 프로젝트에서는 trade-off를 고려해야 합니다.

SAM의 한계와 추가 고려사항

  1. 연산 비용: 1 iteration당 2번의 forward-backward가 필요하므로 학습 시간이 약 2배 증가합니다. GPU 메모리 사용량도 증가합니다.
  2. 하이퍼파라미터 (\rho): 적절한 (\rho) 값을 찾는 것이 중요합니다. 너무 작으면 효과가 없고, 너무 크면 학습이 불안정해집니다. 일반적으로 0.01~0.1 사이에서 튜닝합니다.
  3. 배치 크기: SAM은 배치 통계에 민감하므로, 작은 배치 크기에서는 효과가 줄어들 수 있습니다.
  4. 사전 학습 모델 Fine-tuning: SAM은 특히 사전 학습된 모델을 소규모 데이터셋에 Fine-tuning할 때 가장 큰 효과를 발휘합니다. Foret et al.(2019)의 원 논문에서도 이 부분을 강조합니다.

Netflix Metaflow Spin platform for rapid ML AI development iteration Programming Illustration

결론: 실무에 SAM 도입하기

SAM은 단순한 옵티마이저가 아니라 일반화 성능을 최적화하는 방법론입니다. 국내 개발 환경에서도 다음과 같은 시나리오에서 강력한 도구가 될 수 있습니다.

SAM 도입을 추천하는 경우:

  • 소규모 데이터셋으로 사전 학습 모델을 Fine-tuning할 때
  • 프로덕션 모델의 일반화 성능이 중요할 때 (예: 금융, 의료, 자율주행)
  • 과대적합 문제로 고민 중일 때
  • 학습 시간보다 모델 품질이 더 중요한 프로젝트

다음 단계 학습 방향:

  1. SAM의 이론적 배경을 더 깊이 이해하려면 Foret et al.(2019) 원 논문을 읽어보세요.
  2. SAM의 변형인 **ESAM (Efficient SAM)**이나 **ASAM (Adaptive SAM)**을 살펴보세요. 연산 비용을 줄이거나 성능을 더 개선한 버전입니다.
  3. Hessian 스펙트럼 분석을 통해 SAM이 실제로 더 평평한 최소점에 도달하는지 확인하는 연구도 흥미롭습니다.

SAM은 '만능 도구'는 아니지만, 일반화가 중요한 문제에서는 반드시 고려해야 할 옵션입니다. 특히 최근 LLM Fine-tuning에도 SAM 계열의 옵티마이저가 적용되는 사례가 늘고 있어, 앞으로 더 주목받을 기술입니다.


함께 보면 좋은 글

본 콘텐츠는 신뢰할 수 있는 출처를 바탕으로 AI 도구를 활용하여 초안이 작성되었으며, 편집자의 검토를 거쳐 발행되었습니다. 전문가의 조언을 대체하지 않습니다.