なぜ汎化が問題なのか?過剰パラメータ化のパラドックス

近年のComputer VisionやNLP分野の爆発的な進歩は、過剰パラメータ化されたモデルに支えられています。モデルが学習データを完全に記憶できるほど多くのパラメータを持つという意味です。実際、このようなモデルは学習精度が99%近くに達し、学習損失はほぼゼロになります。

ところが、古典的な機械学習理論では、パラメータが多すぎると**過学習(Overfitting)**が発生し、汎化性能が低下するとされています。しかし現実は逆です。2018年のBelkin et al.や2019年のNakkiran et al.の研究は、**二重降下(Double Descent)**曲線を発見しました。モデルサイズが大きくなるにつれて汎化性能が最初は悪化し(第一の降下)、特定の閾値を超えると再び改善する(第二の降下)現象です。

このパラドックスの鍵は、**損失関数の地形(Loss Landscape)にあります。過剰パラメータ化されたモデルの損失地形には無数の局所的最小点が存在しますが、これらは類似した損失値を持っていても鋭さ(Sharpness)**で大きな違いを示します。鋭い最小点(狭い谷)は小さな変化でも損失が急激に変動しますが、平坦な最小点(広い谷)は変化に対してロバストです。研究により、平坦な最小点がより良い汎化性能と強い相関があることが示されています。

この問題を正面から解決するオプティマイザが、**Sharpness-Aware Minimization (SAM)**です。Foret et al.(2019)が提案したこの手法は、単に損失を最小化するのではなく、周辺近傍で最も損失が大きい方向(敵対的摂動)を見つけ、その損失を最小化することで、平坦な最小点に収束するよう誘導します。

国内開発環境での適用について: 多くのプロジェクトでは事前学習済みモデルをFine-tuningするケースが増えています。SAMは特に小規模データセットでのFine-tuningに非常に効果的です。データが不足しがちな環境で汎化性能を引き上げる強力なツールとなります。

根拠資料: Towards Data Science 原文

Deep learning overparameterization double descent curve sharp minima concept Coding Session Visual

SAMアルゴリズムの数理的直感とPyTorch実装

核心アイデア:2段階最適化

SAMの動作は1回の反復(Iteration)の中で2段階で構成されます。

  1. 第一段階(敵対的摂動探索): 現在の重み (w_t) から半径 (\rho) 以内の近傍で損失を最大化する方向 (\epsilon(w_t)) を求めます。

    [\epsilon(w_t) = \rho \frac{\nabla L(w_t)}{|\nabla L(w_t)|}]

  2. 第二段階(最適化): 見つけた敵対的摂動点 (w_t + \epsilon(w_t)) で勾配を計算し、その勾配を使って元の重みを更新します。

    [w_{t+1} = w_t - \eta \nabla L(w_t + \epsilon(w_t))]

このシンプルな数式が平坦な最小点を見つける秘訣です。

PyTorch実装:SAMオプティマイザクラス

import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    """
    Sharpness-Aware Minimization (SAM) オプティマイザ
    
    Args:
        params: モデルパラメータ
        base_optimizer: ベースオプティマイザ(SGD, Adam等)
        rho: 摂動半径(ハイパーパラメータ、通常0.05)
    """
    def __init__(self, params, base_optimizer, rho=0.05):
        # SAMはベースオプティマイザをラップ
        self.base_optimizer = base_optimizer
        self.rho = rho
        super().__init__(params, defaults={'rho': rho})
        
    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # 第1段階:敵対的摂動を計算し重みに適用
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # 摂動方向 = 正規化された勾配
                grad_norm = p.grad.norm()
                if grad_norm == 0:
                    continue
                e_w = self.rho * p.grad / (grad_norm + 1e-12)
                # 重みに摂動を適用(一時保存)
                p.add_(e_w)
        if zero_grad:
            self.zero_grad()
            
    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # 第2段階:元の重みを復元しベースオプティマイザで更新
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # 摂動除去(元の重みに復元)
                p.sub_(self.rho * p.grad / (p.grad.norm() + 1e-12))
        # ベースオプティマイザのステップ実行
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

学習ループでのSAM使用法

# SAMオプティマイザ生成(SGDをベースに)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.0, weight_decay=0.0)
sam_optimizer = SAM(model.parameters(), base_optimizer, rho=0.05)

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # --- 第一回 forward-backward ---
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        # 第1段階:敵対的摂動適用
        sam_optimizer.first_step(zero_grad=True)
        
        # --- 第二回 forward-backward(摂動された重みで) ---
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        # 第2段階:重み復元および更新
        sam_optimizer.second_step(zero_grad=True)

⚠️ BatchNorm注意点(実務で最も重要な部分)

モデルにBatchNormレイヤーが含まれる場合、第二回forward passでrunning statisticsが更新されないよう必ず無効化する必要があります。そうしないと、摂動された重みの統計が元のモデルのrunning statisticsを汚染します。

def disable_bn_stats(model):
    """
    BatchNormのrunning statistics更新を無効化
    """
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
            module.track_running_stats = False

def enable_bn_stats(model):
    """
    BatchNormのrunning statistics更新を有効化(デフォルト)
    """
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
            module.track_running_stats = True

修正後の学習ループ:

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # 第一回 forward-backward(通常の統計更新)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        sam_optimizer.first_step(zero_grad=True)
        
        # 第二回 forward-backward前にrunning statistics無効化
        disable_bn_stats(model)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        enable_bn_stats(model)  # 次のiterationのために再有効化
        
        sam_optimizer.second_step(zero_grad=True)

SAM optimizer training loop diagram with PyTorch code and BatchNorm handling IT Technology Image

実践実験:Fashion-MNISTでSAMの効果を検証

実験設定

  • データセット: Fashion-MNIST (60,000学習 / 10,000テスト、10クラス、28x28グレースケール)
  • モデル: PreAct ResNet-18(約1,120万パラメータ、パラメータ-サンプル比186:1 → 過剰パラメータ化)
  • 比較: SGD (lr=0.05, momentum=0, weight_decay=0) vs SAM (rho=0.05, 同一SGD base)
  • 公平比較: SAMは1 epochに2回のbackwardを実行するため、SGDは2 epochを1 standardized epochとして扱う
  • 実験期間: 150 standardized epochs (SAM 150 epoch, SGD 300 epoch)

結果概要

指標SGD (150 std epochs)SAM (150 std epochs)SAM (追加50 epochs)
テスト精度92.0%92.5%92.7%
汎化ギャップ (学習-テスト)6.8%2.3%約3%

核心インサイト

  1. テスト精度の差は小さいが、汎化ギャップの差は大きい。 SAMが6.8%→2.3%と汎化ギャップを約3倍削減しました。これはモデルが学習データに過適合していないことを意味します。
  2. SAMはゆっくり学習するが、より平坦な最小点に到達する。 最初の80 epochまではSGDよりテスト精度が低いものの、その後逆転し、より良い汎化を示します。
  3. 追加学習後も汎化ギャップはSGDより大幅に低いまま維持されます。

注意点: SAMはすべての状況でSGD/Adamより優れているわけではありません。特にデータが十分多い場合やモデルが小さい場合、効果は限定的です。また、2倍の計算コスト(forward-backward各2回)が必要なため、学習時間が重要なプロジェクトではトレードオフを考慮する必要があります。

SAMの限界と追加考慮事項

  1. 計算コスト: 1 iterationあたり2回のforward-backwardが必要で、学習時間が約2倍に増加します。GPUメモリ使用量も増加します。
  2. ハイパーパラメータ (\rho): 適切な (\rho) 値を見つけることが重要です。小さすぎると効果がなく、大きすぎると学習が不安定になります。一般的には0.01〜0.1の範囲でチューニングします。
  3. バッチサイズ: SAMはバッチ統計に敏感なため、小さなバッチサイズでは効果が減少する可能性があります。
  4. 事前学習モデルのFine-tuning: SAMは特に事前学習済みモデルを小規模データセットにFine-tuningする際に最も効果を発揮します。Foret et al.(2019)の元論文でもこの点が強調されています。

Netflix Metaflow Spin platform for rapid ML AI development iteration Algorithm Concept Visual

まとめ:実務にSAMを導入するには

SAMは単なるオプティマイザではなく、汎化性能を最適化する方法論です。国内の開発環境でも以下のシナリオで強力なツールとなります。

SAM導入を推奨するケース:

  • 小規模データセットで事前学習モデルをFine-tuningする場合
  • プロダクションモデルの汎化性能が重要な場合(例:金融、医療、自動運転)
  • 過学習の問題で悩んでいる場合
  • 学習時間よりもモデル品質が重要なプロジェクト

次のステップとしての学習方向:

  1. SAMの理論的背景をより深く理解するには、Foret et al.(2019)の原著論文を読むことをお勧めします。
  2. SAMの派生である**ESAM (Efficient SAM)ASAM (Adaptive SAM)**を調査してみてください。計算コストの削減や性能向上を図ったバージョンです。
  3. Hessianスペクトル分析により、SAMが実際により平坦な最小点に到達しているかを確認する研究も興味深いテーマです。

SAMは「万能ツール」ではありませんが、汎化が重要な問題では必ず検討すべき選択肢です。特に最近ではLLMのFine-tuningにもSAM系オプティマイザが適用される事例が増えており、今後さらに注目される技術です。


合わせて読みたい記事

本コンテンツは、信頼性の高い情報源をもとにAIツールを活用して作成され、編集者によるレビューを経て公開されています。専門家によるアドバイスの代替となるものではありません。