月影

日々の雑感

【PyTorchで実装】RNNとAttention:Seq2Seqモデルの性能を最大化する仕組みを解説

 

ブログ連載:AIの「文脈を読む力」の秘密 - Attention機構 完全解説

第2回

【専門編】RNNとAttention:Seq2Seqモデルの性能を最大化する仕組み

前回の【入門編】では、Attentionが「情報のボトルネック」問題を解決するために、AIが翻訳する文章の必要な部分をその都度「カンニング」する画期的なアイデアであることを解説しました。

今回は【専門編】として、その仕組みをコードと共に深く掘り下げます。このRNNベースのAttentionは、専門的には「交差注意機構(Cross-Attention)」の一種に分類されます。これは、デコーダ(訳文を作る側)が、エンコーダ(原文を読む側)という2つの異なる情報源を「交差」して参照することから名付けられました。

本記事では、このCross-Attentionがどのように計算され、Seq2Seqモデルの性能を飛躍的に向上させたのかを、PyTorchによる実装コードを通して詳細に解説します。

Attentionのアーキテクチャ

Attentionメカニズムは、デコーダが各タイムステップで単語を出力する際に、エンコーダのすべての隠れ状態(encoder_outputs)を参照し、動的にコンテキストベクトルを生成する機構です。

キーとなる3つの要素をまず押さえましょう。

encoder_outputs
入力系列の各単語に対応する、エンコーダの全隠れ状態のシーケンスです。デコーダがいつでも参照できる「情報源(リファレンス)」の役割を果たします。
decoder_hidden
デコーダ現在のタイムステップにおける隠れ状態です。今、情報源から何を探すべきかを示す「検索クエリ」の役割を持ちます。
context_vector
上記のencoder_outputs(情報源)を、後述するattention_weightsで加重平均したベクトルです。検索クエリ(decoder_hidden)に基づいて、その時点で最も関連性の高い情報だけを抽出した、「オーダーメイドの要約」です。

Attentionの計算プロセスと実装

Attentionの計算は、主に「①スコア計算」「②重み計算」「③コンテキストベクトル計算」の3つのステップで構成されます。以下に、PyTorchのnn.ModuleとしてAttention層を実装する例を示します。


import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        # スコア計算のために学習される全結合層
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, decoder_hidden, encoder_outputs):
        """
        Args:
            decoder_hidden (Tensor): デコーダーの現在の隠れ状態 [1, batch_size, hidden_size]
            encoder_outputs (Tensor): エンコーダーの全出力 [seq_len, batch_size, hidden_size]
        """
        seq_len = encoder_outputs.size(0)
        
        # decoder_hiddenをencoder_outputsの長さに合わせて複製
        # [1, B, H] -> [S, B, H]
        hidden_repeated = decoder_hidden.repeat(seq_len, 1, 1)
        
        # --- ① スコア計算 (Alignment Score) ---
        # concatしたベクトルを線形層とtanhに通し、エネルギーを計算
        # [S, B, H*2] -> [S, B, H]
        attn_inputs = torch.cat((hidden_repeated, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(attn_inputs))
        
        # エネルギーと学習可能なベクトルvの内積で最終的なスコアを計算
        # [S, B, H] -> [S, B]
        energy = energy.permute(1, 0, 2) # [B, S, H]
        v_permuted = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1) # [B, 1, H]
        attn_scores = torch.bmm(v_permuted, energy.transpose(1, 2)).squeeze(1) # [B, S]

        # --- ② 重み計算 (Attention Weights) ---
        # スコアをsoftmaxで正規化し、確率分布(重み)に変換
        return F.softmax(attn_scores, dim=1)

# --- 実行例 ---
hidden_size = 128
seq_len = 10
batch_size = 32

# モデルとダミーデータの準備
attention_layer = Attention(hidden_size)
decoder_hidden = torch.randn(1, batch_size, hidden_size)
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)

# Attention Weightsの計算
attention_weights = attention_layer(decoder_hidden, encoder_outputs)

# --- ③ コンテキストベクトル計算 (Context Vector) ---
# attention_weights [B, S] を使ってencoder_outputs [S, B, H] の加重平均を計算
# まず次元を合わせる [B, 1, S]
weights_reshaped = attention_weights.unsqueeze(1)
# encoder_outputsの次元を入れ替える [B, S, H]
encoder_outputs_permuted = encoder_outputs.permute(1, 0, 2)
# バッチごとに行列積を計算
context_vector = torch.bmm(weights_reshaped, encoder_outputs_permuted) # [B, 1, H]

print(f"Attention Weights Shape: {attention_weights.shape}") # -> [32, 10]
print(f"Context Vector Shape: {context_vector.shape}")  # -> [32, 1, 128]
                

デコーダーのループへの統合

このAttentionモジュールは、デコーダーが単語を一つずつ生成するループの中で、以下のように使われます。

  1. 現在のdecoder_hiddenをクエリとして、Attentionモジュールを呼び出し、attention_weightsを得る。
  2. 得られたattention_weightsencoder_outputsで加重平均をとり、context_vectorを計算する。
  3. このcontext_vectorと、一つ前のタイムステップで予測した単語の埋め込みベクトルを結合(torch.cat)する。
  4. 結合したベクトルをデコーダーのRNNセル(GRU/LSTM)に入力し、次のdecoder_hiddenを計算する。
  5. 同時に、結合したベクトルを出力層(nn.Linear -> F.log_softmax)に入力し、現在のタイムステップで出力する単語を予測する。
  6. 文末記号が出力されるまで、このループを繰り返す。

まとめ

RNNベースのAttention(Cross-Attention)は、デコーダが各ステップで入力系列のどこに注目すべきかを動的に学習する強力なメカニズムです。

これにより、固定長のコンテキストベクトルという制約からモデルを解放し、特に長い系列に対するSeq2Seqの性能を劇的に向上させました。この「異なる情報源を相互参照する」というアイデアは、後のTransformerアーキテクチャにおけるSelf-Attentionの基礎ともなっており、現代のNLP技術を理解する上で不可欠な要素と言えるでしょう。

次回は、いよいよ現代のLLMの基礎であるTransformerと、その心臓部である「自己注意機構(Self-Attention)」の解説に入ります。