ブログ連載:AIの「文脈を読む力」の秘密 - Attention機構 完全解説
第2回
【専門編】RNNとAttention:Seq2Seqモデルの性能を最大化する仕組み
前回の【入門編】では、Attentionが「情報のボトルネック」問題を解決するために、AIが翻訳する文章の必要な部分をその都度「カンニング」する画期的なアイデアであることを解説しました。
今回は【専門編】として、その仕組みをコードと共に深く掘り下げます。このRNNベースのAttentionは、専門的には「交差注意機構(Cross-Attention)」の一種に分類されます。これは、デコーダ(訳文を作る側)が、エンコーダ(原文を読む側)という2つの異なる情報源を「交差」して参照することから名付けられました。
本記事では、このCross-Attentionがどのように計算され、Seq2Seqモデルの性能を飛躍的に向上させたのかを、PyTorchによる実装コードを通して詳細に解説します。
Attentionのアーキテクチャ
Attentionメカニズムは、デコーダが各タイムステップで単語を出力する際に、エンコーダのすべての隠れ状態(encoder_outputs)を参照し、動的にコンテキストベクトルを生成する機構です。
キーとなる3つの要素をまず押さえましょう。
encoder_outputs- 入力系列の各単語に対応する、エンコーダの全隠れ状態のシーケンスです。デコーダがいつでも参照できる「情報源(リファレンス)」の役割を果たします。
decoder_hidden- デコーダの現在のタイムステップにおける隠れ状態です。今、情報源から何を探すべきかを示す「検索クエリ」の役割を持ちます。
context_vector- 上記の
encoder_outputs(情報源)を、後述するattention_weightsで加重平均したベクトルです。検索クエリ(decoder_hidden)に基づいて、その時点で最も関連性の高い情報だけを抽出した、「オーダーメイドの要約」です。
Attentionの計算プロセスと実装
Attentionの計算は、主に「①スコア計算」「②重み計算」「③コンテキストベクトル計算」の3つのステップで構成されます。以下に、PyTorchのnn.ModuleとしてAttention層を実装する例を示します。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
# スコア計算のために学習される全結合層
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, decoder_hidden, encoder_outputs):
"""
Args:
decoder_hidden (Tensor): デコーダーの現在の隠れ状態 [1, batch_size, hidden_size]
encoder_outputs (Tensor): エンコーダーの全出力 [seq_len, batch_size, hidden_size]
"""
seq_len = encoder_outputs.size(0)
# decoder_hiddenをencoder_outputsの長さに合わせて複製
# [1, B, H] -> [S, B, H]
hidden_repeated = decoder_hidden.repeat(seq_len, 1, 1)
# --- ① スコア計算 (Alignment Score) ---
# concatしたベクトルを線形層とtanhに通し、エネルギーを計算
# [S, B, H*2] -> [S, B, H]
attn_inputs = torch.cat((hidden_repeated, encoder_outputs), dim=2)
energy = torch.tanh(self.attn(attn_inputs))
# エネルギーと学習可能なベクトルvの内積で最終的なスコアを計算
# [S, B, H] -> [S, B]
energy = energy.permute(1, 0, 2) # [B, S, H]
v_permuted = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1) # [B, 1, H]
attn_scores = torch.bmm(v_permuted, energy.transpose(1, 2)).squeeze(1) # [B, S]
# --- ② 重み計算 (Attention Weights) ---
# スコアをsoftmaxで正規化し、確率分布(重み)に変換
return F.softmax(attn_scores, dim=1)
# --- 実行例 ---
hidden_size = 128
seq_len = 10
batch_size = 32
# モデルとダミーデータの準備
attention_layer = Attention(hidden_size)
decoder_hidden = torch.randn(1, batch_size, hidden_size)
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
# Attention Weightsの計算
attention_weights = attention_layer(decoder_hidden, encoder_outputs)
# --- ③ コンテキストベクトル計算 (Context Vector) ---
# attention_weights [B, S] を使ってencoder_outputs [S, B, H] の加重平均を計算
# まず次元を合わせる [B, 1, S]
weights_reshaped = attention_weights.unsqueeze(1)
# encoder_outputsの次元を入れ替える [B, S, H]
encoder_outputs_permuted = encoder_outputs.permute(1, 0, 2)
# バッチごとに行列積を計算
context_vector = torch.bmm(weights_reshaped, encoder_outputs_permuted) # [B, 1, H]
print(f"Attention Weights Shape: {attention_weights.shape}") # -> [32, 10]
print(f"Context Vector Shape: {context_vector.shape}") # -> [32, 1, 128]
デコーダーのループへの統合
このAttentionモジュールは、デコーダーが単語を一つずつ生成するループの中で、以下のように使われます。
- 現在の
decoder_hiddenをクエリとして、Attentionモジュールを呼び出し、attention_weightsを得る。 - 得られた
attention_weightsとencoder_outputsで加重平均をとり、context_vectorを計算する。 - この
context_vectorと、一つ前のタイムステップで予測した単語の埋め込みベクトルを結合(torch.cat)する。 - 結合したベクトルをデコーダーのRNNセル(GRU/LSTM)に入力し、次の
decoder_hiddenを計算する。 - 同時に、結合したベクトルを出力層(
nn.Linear->F.log_softmax)に入力し、現在のタイムステップで出力する単語を予測する。 - 文末記号が出力されるまで、このループを繰り返す。