月影

日々の雑感

【PyTorch】WideResNetのコードを完全理解する穴埋めドリル|実装と理論の演習

 

【深層学習ドリル】WideResNetの実装とShape計算

「コードをコピペすれば動くけれど、中身の処理がイメージできない…」

深層学習の勉強中に、そんな壁にぶつかったことはありませんか?
PyTorchなどのフレームワークで最も躓きやすいのは、構文エラーではなくテンソル(行列)の形状変化」の理解です。

今回は、画像認識モデルの決定版である WideResNet を題材に、あえてコードを「穴埋め」形式にしました。
単に読むだけでなく、クイズに答えるつもりで「データの次元(Shape)」を予測してみてください。
これが解ければ、CNNの設計図が頭の中に浮かぶようになります。

第1問:WideBasicBlockの実装

ResNetの基本ブロックにおける「畳み込み処理」と「ショートカットの合流」を完成させてください。 以下のコードの 問X の部分に入る正しいコードは何でしょうか?

class WideBasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(WideBasicBlock, self).__init__()
        
        # 1つ目の畳み込み
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, ...)
        self.bn1 = nn.BatchNorm2d(planes)
        
        # 2つ目の畳み込み
        # 【問1】入力chと出力chはどうなる?
        self.conv2 = nn.Conv2d(【問1】, kernel_size=3, stride=1, ...)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = ... # 次元合わせの実装

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # 【問2】Residual Connection(残差接続)の実装
        # ここで何と何を足し合わせるか?
        out += 【問2】
        
        out = F.relu(out)
        return out
答えと解説を見る

【問1】 planes, planes

解説: ResNetのBasicBlock内では、基本的にチャンネル数は変化させません。1層目(conv1)で既に目標のチャンネル数(planes)になっているため、2層目(conv2)もそれを受け取り、そのまま出力します。


【問2】 self.shortcut(x)

解説: これがResNetの核心である「スキップ接続」です。畳み込みを通って特徴抽出された out に、元の入力 x を加算します。
※次元が異なる場合に備えて、生の x ではなく self.shortcut(x) と記述するのが定石です。

第2問:テンソルのShape予測

モデルの中でデータがどう変形していくかを追跡します。
条件:widen_factor (k) = 2 とします。

# 条件: k=2, 入力画像: [Batch=1, RGB=3, 32x32]

# --- 最初の畳み込み ---
# Conv2d(3, 16) -> 出力: [1, 16, 32, 32]

# --- Stage 1 (stride=1) ---
# WideBasicBlock(in=16, out=16*k, stride=1)
# 【問3】出力Shapeは?
# 答え: [1, 【問3】, 32, 32]

# --- Stage 2 (stride=2) ---
# WideBasicBlock(in=前層の出力, out=32*k, stride=2)
# 【問4】出力Shapeは?
# 答え: [1, 【問4 A】, 【問4 B】, 【問4 B】]
答えと解説を見る

【問3】 32

解説: ベースの16チャンネルに k=2 を掛けるため、16 × 2 = 32 となります。stride=1 なので画像サイズ(32x32)は変わりません。


【問4 A】 64
【問4 B】 16

解説: チャンネル数:ベース32 × k(2) = 64
画像サイズ:stride=2 なので、32 ÷ 2 = 16 に半減します。
最終的なShapeは [1, 64, 16, 16] です。

第3問:WideResNetの理論と目的(基礎知識)

コードを書く前に、「なぜWideResNetを使うのか?」を知っておくことが重要です。
以下の文章の空欄を埋めて、WideResNetの核心を理解しましょう。

Q1. ResNetの弱点とWideResNetの解決策

オリジナルのResNetは、精度を上げるために層を非常に「深く(Deep)」しました(例:1000層)。
しかし、層があまりに深すぎると、学習にかかる時間が長くなる上、勾配消失などの問題で効果が薄れる(Diminishing returns)ことがわかりました。

そこでWideResNetでは、層を増やす代わりに、各層の【A】を増やす(=畳み込みフィルタの数を増やす)アプローチを取りました。
これにより、同じ精度のモデルをより浅い層で実現しました。

Q2. 計算速度のメリット

深いモデル(Deep)よりも広いモデル(Wide)の方が、GPUでの学習速度が速くなる傾向があります。
これは、GPUが「逐次的な計算(前の層が終わるのを待つ)」よりも、大きな行列を一度に計算する【B】化」が得意だからです。

Q3. 追加された正則化技術

WideResNetでは、パラメータ数が増えることによる過学習(Overfitting)を防ぐため、畳み込み層の間に【C】を入れることが有効であると報告されています。
(※オリジナルのResNetではあまり使われませんが、WideResNetでは効果的です)

答えと解説を見る

【A】 チャンネル数(幅 / Width)

解説: モデルの「幅(Width)」とは、各層におけるチャンネル数(特徴マップの枚数)のことです。WideResNetは、係数 k (widen factor) を使ってこの幅を広げます。


【B】 並列(Parallel)

解説: GPUは「一度に大量の計算をする(並列計算)」のが得意です。層を深くすると「前の層の結果待ち」が何度も発生しますが、幅を広げると1回の計算量が増え、GPUの能力をフル活用できるため、結果的に学習が速くなります。


【C】 Dropout(ドロップアウト

解説: WideResNetの論文では、層の間にDropoutを入れることで、過学習を抑えて精度が向上したと報告されています。これがコード内の conv1conv2 の間に入ることがあります。

第4問:学習ループの「黄金の3ステップ」

モデルを動かすための「学習ループ」にも、呪文のように唱えるべき手順があります。
勾配(gradient)の計算とパラメータ更新の流れを完成させてください。

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train() # 学習モード

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # 【問1】勾配の初期化
        # PyTorchは勾配を累積するため、ループの最初にリセットが必要です。
        optimizer.【問1】()

        # 2. 順伝播 (Forward)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # 【問2】逆伝播 (Backward)
        # 誤差から勾配を計算します。
        loss.【問2】()

        # 【問3】パラメータ更新
        # 計算された勾配を使って重みを更新します。
        optimizer.【問3】()
        
    return loss.item()
答えと解説を見る

【問1】 zero_grad

解説: optimizer.zero_grad()
イテレーションの最初に勾配をゼロにリセットします。


【問2】 backward

解説: loss.backward()
誤差逆伝播を行い、勾配を計算します。


【問3】 step

解説: optimizer.step()
計算された勾配を使ってパラメータを更新します。

第5問:推論(評価)モードの作法

評価時(Validation)の無駄な計算を防ぎ、正しい結果を得るためのコードを埋めてください。

def validate(model, dataloader, device):
    # 【問4】モデルを評価モードにする
    # (Dropoutを無効化し、BatchNormの統計量を固定するため)
    model.【問4】()

    # 【問5】勾配計算の無効化
    # 評価時は勾配計算が不要。メモリ節約と高速化のためにこれを使います。
    with torch.【問5】():
        for inputs, targets in dataloader:
            # ...推論処理...
答えと解説を見る

【問4】 eval

解説: model.eval()
推論モードへの切り替えです。


【問5】 no_grad

解説: torch.no_grad()
勾配計算を無効化し、メモリを節約します。

【完全版】WideResNet 全体コード

これまでのドリルで学んだ内容を結合し、実際に動作するクラスとして組み上げた完全なコードです。
import 文から動作確認用の main ブロックまで全て含まれています。

使い方: 右上のコピーボタン等でコードをコピーし、Python環境で実行してください。
出力として、ネットワークの出力サイズ(例: torch.Size([2, 10]))が表示されれば成功です。
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------------------------------------------
# 1. ブロックの定義 (ドリル第1問・第2問の答え)
# --------------------------------------------------------
class WideBasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(WideBasicBlock, self).__init__()
        
        # 1層目: 入力ch -> 出力ch
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        
        # 2層目: 出力ch -> 出力ch (サイズは変えない)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # ショートカット(次元合わせ)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        # 論文によっては BN->ReLU->Conv の順序(Pre-activation)もありますが、
        # ここでは一般的な ResNet (Post-activation) の構成で記述しています。
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # Skip Connection: 元の入力(x)を足し合わせる
        out += self.shortcut(x)
        
        out = F.relu(out)
        return out

# --------------------------------------------------------
# 2. メインモデルの定義 (WideResNet本体)
# --------------------------------------------------------
class WideResNet(nn.Module):
    def __init__(self, depth=28, widen_factor=10, num_classes=10):
        super(WideResNet, self).__init__()
        self.in_planes = 16
        
        # ネットワークの深さから、各ステージのブロック数を計算
        # depth = conv1 + 3 * n * 2(block) + linear 
        # (depth - 4) / 6 という計算式が一般的です
        assert (depth - 4) % 6 == 0, 'Depth should be 6n+4'
        n = (depth - 4) // 6
        
        k = widen_factor
        
        # --- 層の構築 ---
        # 最初の畳み込み
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        
        # 3つのステージ (Stage 1, 2, 3)
        # チャンネル数が 16 -> 16*k -> 32*k -> 64*k と増えていく
        self.layer1 = self._make_layer(16*k, n, stride=1)
        self.layer2 = self._make_layer(32*k, n, stride=2)
        self.layer3 = self._make_layer(64*k, n, stride=2)
        
        # 最後の全結合層
        self.linear = nn.Linear(64*k, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(WideBasicBlock(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        # [Batch, 3, 32, 32]
        out = F.relu(self.bn1(self.conv1(x)))
        
        # 各ステージを通す
        out = self.layer1(out) # -> [Batch, 16*k, 32, 32]
        out = self.layer2(out) # -> [Batch, 32*k, 16, 16]
        out = self.layer3(out) # -> [Batch, 64*k, 8, 8]
        
        # Global Average Pooling
        out = F.avg_pool2d(out, 8) # -> [Batch, 64*k, 1, 1]
        out = out.view(out.size(0), -1) # Flatten
        
        # 全結合層
        out = self.linear(out)
        return out

# --------------------------------------------------------
# 3. 動作確認用ブロック (実行部分)
# --------------------------------------------------------
def test_run():
    # モデルのインスタンス化 (軽量な設定: depth=16, k=2)
    net = WideResNet(depth=16, widen_factor=2, num_classes=10)
    
    # ダミー入力データの作成 (Batch=2, RGB, 32x32)
    x = torch.randn(2, 3, 32, 32)
    
    # 評価モードへ (ドリル第5問の答え)
    net.eval()
    
    # 推論実行 (ドリル第5問の答え)
    with torch.no_grad():
        y = net(x)
        
    print("-" * 30)
    print(f"Input  Shape: {x.shape}")
    print(f"Output Shape: {y.shape}") # 予想: [2, 10]
    print("-" * 30)
    print("WideResNetの構築と順伝播に成功しました!")

if __name__ == "__main__":
    test_run()

【コラム】Pythonで構造図を自動生成する

モデルの構造を理解するには、図を描くのが一番です。
実は、Pythonのライブラリ graphviz を使えば、WideResNetのブロック図をコードから生成することができます。
学習の記録やブログ作成に役立つので、興味がある方は試してみてください。

※実行には pip install graphvizGraphviz本体のインストールが必要です。
from graphviz import Digraph

def draw_wideresnet_block():
    dot = Digraph(format='png')
    dot.attr(rankdir='TB') # 上から下へ

    # ノード定義
    dot.node('Input', 'Input\n(16ch)', shape='box')
    dot.node('BN1', 'BN + ReLU', shape='box')
    dot.node('Conv1', 'Conv 3x3\n(32ch)', shape='box', style='filled', fillcolor='lightblue')
    dot.node('BN2', 'BN + ReLU', shape='box')
    dot.node('Conv2', 'Conv 3x3\n(32ch)', shape='box', style='filled', fillcolor='lightblue')
    dot.node('Add', 'Add (+)', shape='circle')
    dot.node('Output', 'Output', shape='box')

    # メインパスの接続
    dot.edge('Input', 'BN1')
    dot.edge('BN1', 'Conv1')
    dot.edge('Conv1', 'BN2')
    dot.edge('BN2', 'Conv2')
    dot.edge('Conv2', 'Add')
    dot.edge('Add', 'Output')

    # ショートカットパス(1x1 Convで次元合わせ)
    dot.node('Shortcut', 'Conv 1x1\n(32ch)', shape='box', style='filled', fillcolor='lightyellow')
    dot.edge('Input', 'Shortcut', label=' Skip Connection')
    dot.edge('Shortcut', 'Add')

    dot.render('wideresnet_block', view=True)

draw_wideresnet_block()

今回のドリルで、モデルの「形」と「動き」の両方がイメージできたかと思います。
次回は実際のデータセットを使って学習させてみましょう。